diff --git a/README.md b/README.md
index 218b0748..670c5ad9 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
# DataOps Data Quality TestGen
-  [](https://hub.docker.com/r/datakitchen/dataops-testgen) [](https://hub.docker.com/r/datakitchen/dataops-testgen) [](https://docs.datakitchen.io/articles/#!dataops-testgen-help/dataops-testgen-help) [](https://data-observability-slack.datakitchen.io/join)
+  [](https://hub.docker.com/r/datakitchen/dataops-testgen) [](https://hub.docker.com/r/datakitchen/dataops-testgen) [](https://docs.datakitchen.io/articles/dataops-testgen-help/dataops-testgen-help) [](https://data-observability-slack.datakitchen.io/join)
*
DataOps Data Quality TestGen, or "TestGen" for short, can help you find data issues so you can alert your users and notify your suppliers. It does this by delivering simple, fast data quality test generation and execution by data profiling, new dataset screening and hygiene review, algorithmic generation of data quality validation tests, ongoing production testing of new data refreshes, and continuous anomaly monitoring of datasets. TestGen is part of DataKitchen's Open Source Data Observability.
*
@@ -110,7 +110,7 @@ Within the virtual environment, install the TestGen package with pip.
pip install dataops-testgen
```
-Verify that the [_testgen_ command line](https://docs.datakitchen.io/articles/#!dataops-testgen-help/testgen-commands-and-details) works.
+Verify that the [_testgen_ command line](https://docs.datakitchen.io/articles/dataops-testgen-help/testgen-commands-and-details) works.
```shell
testgen --help
```
@@ -187,7 +187,7 @@ python3 dk-installer.py tg delete-demo
### Upgrade to latest version
-New releases of TestGen are announced on the `#releases` channel on [Data Observability Slack](https://data-observability-slack.datakitchen.io/join), and release notes can be found on the [DataKitchen documentation portal](https://docs.datakitchen.io/articles/#!dataops-testgen-help/testgen-release-notes/a/h1_1691719522). Use the following command to upgrade to the latest released version.
+New releases of TestGen are announced on the `#releases` channel on [Data Observability Slack](https://data-observability-slack.datakitchen.io/join), and release notes can be found on the [DataKitchen documentation portal](https://docs.datakitchen.io/articles/dataops-testgen-help/testgen-release-notes/a/h1_1691719522). Use the following command to upgrade to the latest released version.
```shell
python3 dk-installer.py tg upgrade
@@ -203,7 +203,7 @@ python3 dk-installer.py tg delete
### Access the _testgen_ CLI
-The [_testgen_ command line](https://docs.datakitchen.io/articles/#!dataops-testgen-help/testgen-commands-and-details) can be accessed within the running container.
+The [_testgen_ command line](https://docs.datakitchen.io/articles/dataops-testgen-help/testgen-commands-and-details) can be accessed within the running container.
```shell
docker compose exec engine bash
@@ -232,7 +232,7 @@ We recommend you start by going through the [Data Observability Overview Demo](h
For support requests, [join the Data Observability Slack](https://data-observability-slack.datakitchen.io/join) π and post on the `#support` channel.
### Connect to your database
-Follow [these instructions](https://docs.datakitchen.io/articles/#!dataops-testgen-help/connect-your-database) to improve the quality of data in your database.
+Follow [these instructions](https://docs.datakitchen.io/articles/dataops-testgen-help/connect-your-database) to improve the quality of data in your database.
### Community
Talk and learn with other data practitioners who are building with DataKitchen. Share knowledge, get help, and contribute to our open-source project.
diff --git a/deploy/charts/README.md b/deploy/charts/README.md
index a8656084..c081df4b 100644
--- a/deploy/charts/README.md
+++ b/deploy/charts/README.md
@@ -40,9 +40,14 @@ set can be easily used on the first install and future upgrades.
The following configuration is recommended for experimental installations, but
you're free to adjust it for your needs. The next installation steps assumes
-that a file named tg-values.yaml exists with this configuration.
+that a file named `tg-values.yaml` exists with this configuration.
```yaml
+image:
+
+ # DataOps TestGen version to be installed / upgraded
+ tag: v4
+
testgen:
# Password that will be assigned to the 'admin' user during the database preparation
@@ -54,10 +59,18 @@ testgen:
# Whether to run the SSL certificate verifications when connecting to DataOps Observability
observabilityVerifySsl: false
-image:
+ # (Optional) E-mail and SMTP configurations for enabling the email notifications
+ emailNotifications:
- # DataOps TestGen version to be installed / upgraded
- tag: v4.0
+ # The email address that notifications will be sent from
+ fromAddress:
+
+ # SMTP configuration for sending emails
+ smtp:
+ endpoint:
+ port:
+ username:
+ password:
```
# Installing
diff --git a/deploy/charts/testgen-app/Chart.yaml b/deploy/charts/testgen-app/Chart.yaml
index 01b2c072..ac7f2e5d 100644
--- a/deploy/charts/testgen-app/Chart.yaml
+++ b/deploy/charts/testgen-app/Chart.yaml
@@ -15,7 +15,7 @@ type: application
# This is the chart version. This version number should be incremented each time you make changes
# to the chart and its templates, including the app version.
# Versions are expected to follow Semantic Versioning (https://semver.org/)
-version: 1.0.1
+version: 1.1.0
# This is the version number of the application being deployed. This version number should be
# incremented each time you make changes to the application. Versions are not expected to
diff --git a/deploy/charts/testgen-app/templates/_environment.yaml b/deploy/charts/testgen-app/templates/_environment.yaml
index c329f75c..5bf01711 100644
--- a/deploy/charts/testgen-app/templates/_environment.yaml
+++ b/deploy/charts/testgen-app/templates/_environment.yaml
@@ -31,6 +31,18 @@
value: {{ .Values.testgen.trustTargetDatabaseCertificate | ternary "yes" "no" | quote }}
- name: TG_EXPORT_TO_OBSERVABILITY_VERIFY_SSL
value: {{ .Values.testgen.observabilityVerifySsl | ternary "yes" "no" | quote }}
+{{- with .Values.testgen.emailNotifications }}
+- name: TG_SMTP_ENDPOINT
+ value: {{ .smtp.endpoint | quote }}
+- name: TG_SMTP_PORT
+ value: {{ .smtp.port | quote }}
+- name: TG_SMTP_USERNAME
+ value: {{ .smtp.username | quote }}
+- name: TG_SMTP_PASSWORD
+ value: {{ .smtp.password | quote }}
+- name: TG_EMAIL_FROM_ADDRESS
+ value: {{ .fromAddress | quote }}
+{{- end -}}
{{- end -}}
{{- define "testgen.hookEnvironment" -}}
diff --git a/deploy/charts/testgen-app/values.yaml b/deploy/charts/testgen-app/values.yaml
index d958ae09..8018a4cc 100644
--- a/deploy/charts/testgen-app/values.yaml
+++ b/deploy/charts/testgen-app/values.yaml
@@ -1,12 +1,12 @@
# Default values for testgen.
testgen:
- databaseHost: "postgresql"
+ databaseHost: "postgres"
databaseName: "datakitchen"
databaseSchema: "tgapp"
databaseUser: "postgres"
databasePasswordSecret:
- name: "postgresql"
+ name: "services-secret"
key: "postgres-password"
authSecrets:
create: true
@@ -15,6 +15,13 @@ testgen:
uiPassword:
trustTargetDatabaseCertificate: false
observabilityVerifySsl: true
+ emailNotifications:
+ fromAddress:
+ smtp:
+ endpoint:
+ port:
+ username:
+ password:
labels:
cliHooks:
diff --git a/deploy/charts/testgen-services/Chart.lock b/deploy/charts/testgen-services/Chart.lock
deleted file mode 100644
index 72fc8358..00000000
--- a/deploy/charts/testgen-services/Chart.lock
+++ /dev/null
@@ -1,6 +0,0 @@
-dependencies:
-- name: postgresql
- repository: https://charts.bitnami.com/bitnami
- version: 16.3.0
-digest: sha256:92eb2890efc38c617fa56144f4f54c0ac1ee11818f6b00860ec00a87be48f249
-generated: "2025-02-24T09:05:32.15558-05:00"
diff --git a/deploy/charts/testgen-services/Chart.yaml b/deploy/charts/testgen-services/Chart.yaml
index 8e98b830..954994c9 100644
--- a/deploy/charts/testgen-services/Chart.yaml
+++ b/deploy/charts/testgen-services/Chart.yaml
@@ -12,18 +12,13 @@ description: A Helm chart for Kubernetes
# pipeline. Library charts do not define any templates and therefore cannot be deployed.
type: application
-dependencies:
- - name: postgresql
- version: 16.3.0
- repository: https://charts.bitnami.com/bitnami
-
# This is the chart version. This version number should be incremented each time you make changes
# to the chart and its templates, including the app version.
# Versions are expected to follow Semantic Versioning (https://semver.org/)
-version: 0.1.1
+version: 1.0.0
# This is the version number of the application being deployed. This version number should be
# incremented each time you make changes to the application. Versions are not expected to
# follow Semantic Versioning. They should reflect the version the application is using.
# It is recommended to use it with quotes.
-appVersion: "1.16.0"
+appVersion: "4"
diff --git a/deploy/charts/testgen-services/templates/_helpers.tpl b/deploy/charts/testgen-services/templates/_helpers.tpl
new file mode 100644
index 00000000..de964984
--- /dev/null
+++ b/deploy/charts/testgen-services/templates/_helpers.tpl
@@ -0,0 +1,62 @@
+{{/*
+Expand the name of the chart.
+*/}}
+{{- define "testgen-services.name" -}}
+{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
+{{- end }}
+
+{{/*
+Create a default fully qualified app name.
+We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
+If release name contains chart name it will be used as a full name.
+*/}}
+{{- define "testgen-services.fullname" -}}
+{{- if .Values.fullnameOverride }}
+{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
+{{- else }}
+{{- $name := default .Chart.Name .Values.nameOverride }}
+{{- if contains $name .Release.Name }}
+{{- .Release.Name | trunc 63 | trimSuffix "-" }}
+{{- else }}
+{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
+{{- end }}
+{{- end }}
+{{- end }}
+
+{{/*
+Create chart name and version as used by the chart label.
+*/}}
+{{- define "testgen-services.chart" -}}
+{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
+{{- end }}
+
+{{/*
+Common labels
+*/}}
+{{- define "testgen-services.labels" -}}
+helm.sh/chart: {{ include "testgen-services.chart" . }}
+{{ include "testgen-services.selectorLabels" . }}
+{{- if .Chart.AppVersion }}
+app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
+{{- end }}
+app.kubernetes.io/managed-by: {{ .Release.Service }}
+{{- end }}
+
+{{/*
+Selector labels
+*/}}
+{{- define "testgen-services.selectorLabels" -}}
+app.kubernetes.io/name: {{ include "testgen-services.name" . }}
+app.kubernetes.io/instance: {{ .Release.Name }}
+{{- end }}
+
+{{/*
+Create the name of the service account to use
+*/}}
+{{- define "testgen-services.serviceAccountName" -}}
+{{- if .Values.serviceAccount.create }}
+{{- default (include "testgen-services.fullname" .) .Values.serviceAccount.name }}
+{{- else }}
+{{- default "default" .Values.serviceAccount.name }}
+{{- end }}
+{{- end }}
diff --git a/deploy/charts/testgen-services/templates/secret.yaml b/deploy/charts/testgen-services/templates/secret.yaml
new file mode 100644
index 00000000..eb894fd6
--- /dev/null
+++ b/deploy/charts/testgen-services/templates/secret.yaml
@@ -0,0 +1,9 @@
+{{- if .Values.secret.create }}
+apiVersion: v1
+kind: Secret
+metadata:
+ name: {{ .Values.secret.name }}
+type: Opaque
+data:
+ {{ .Values.postgres.auth.passwordKey }}: {{ randAlphaNum 24 | b64enc }}
+{{- end }}
diff --git a/deploy/charts/testgen-services/templates/service.yaml b/deploy/charts/testgen-services/templates/service.yaml
new file mode 100644
index 00000000..f53c5c82
--- /dev/null
+++ b/deploy/charts/testgen-services/templates/service.yaml
@@ -0,0 +1,15 @@
+apiVersion: v1
+kind: Service
+metadata:
+ name: {{ include "testgen-services.fullname" . }}
+ labels:
+ {{- include "testgen-services.labels" . | nindent 4 }}
+spec:
+ type: {{ .Values.postgres.service.type }}
+ ports:
+ - port: {{ .Values.postgres.service.port }}
+ targetPort: 5432
+ protocol: TCP
+ name: postgres
+ selector:
+ {{- include "testgen-services.selectorLabels" . | nindent 4 }}
diff --git a/deploy/charts/testgen-services/templates/serviceaccount.yaml b/deploy/charts/testgen-services/templates/serviceaccount.yaml
new file mode 100644
index 00000000..b7873994
--- /dev/null
+++ b/deploy/charts/testgen-services/templates/serviceaccount.yaml
@@ -0,0 +1,13 @@
+{{- if .Values.serviceAccount.create -}}
+apiVersion: v1
+kind: ServiceAccount
+metadata:
+ name: {{ include "testgen-services.serviceAccountName" . }}
+ labels:
+ {{- include "testgen-services.labels" . | nindent 4 }}
+ {{- with .Values.serviceAccount.annotations }}
+ annotations:
+ {{- toYaml . | nindent 4 }}
+ {{- end }}
+automountServiceAccountToken: {{ .Values.serviceAccount.automount }}
+{{- end }}
diff --git a/deploy/charts/testgen-services/templates/statefulset.yaml b/deploy/charts/testgen-services/templates/statefulset.yaml
new file mode 100644
index 00000000..608ed230
--- /dev/null
+++ b/deploy/charts/testgen-services/templates/statefulset.yaml
@@ -0,0 +1,75 @@
+apiVersion: apps/v1
+kind: StatefulSet
+metadata:
+ name: {{ include "testgen-services.fullname" . }}
+ labels:
+ {{- include "testgen-services.labels" . | nindent 4 }}
+spec:
+ replicas: {{ .Values.postgres.replicaCount }}
+ selector:
+ matchLabels:
+ {{- include "testgen-services.selectorLabels" . | nindent 6 }}
+ template:
+ metadata:
+ {{- with .Values.podAnnotations }}
+ annotations:
+ {{- toYaml . | nindent 8 }}
+ {{- end }}
+ labels:
+ {{- include "testgen-services.labels" . | nindent 8 }}
+ {{- with .Values.podLabels }}
+ {{- toYaml . | nindent 8 }}
+ {{- end }}
+ spec:
+ {{- with .Values.postgres.imagePullSecrets }}
+ imagePullSecrets:
+ {{- toYaml . | nindent 8 }}
+ {{- end }}
+ serviceAccountName: {{ include "testgen-services.serviceAccountName" . }}
+ securityContext:
+ {{- toYaml .Values.podSecurityContext | nindent 8 }}
+ containers:
+ - name: {{ .Chart.Name }}-postgres
+ securityContext:
+ {{- toYaml .Values.securityContext | nindent 12 }}
+ image: "{{ .Values.postgres.image.repository }}:{{ .Values.postgres.image.tag }}"
+ imagePullPolicy: {{ .Values.postgres.image.pullPolicy }}
+ ports:
+ - containerPort: 5432
+ name: postgres
+ env:
+ - name: POSTGRES_USER
+ value: {{ .Values.postgres.auth.user }}
+ - name: POSTGRES_DB
+ value: {{ .Values.postgres.auth.database }}
+ - name: POSTGRES_PASSWORD
+ valueFrom:
+ secretKeyRef:
+ name: {{ .Values.secret.name }}
+ key: {{ .Values.postgres.auth.passwordKey }}
+ volumeMounts:
+ - name: data
+ mountPath: /var/lib/postgresql/data
+ readinessProbe:
+ exec:
+ command: ["pg_isready", "-U", "{{ .Values.postgres.auth.user }}"]
+ initialDelaySeconds: 5
+ periodSeconds: 10
+ livenessProbe:
+ exec:
+ command: ["pg_isready", "-U", "{{ .Values.postgres.auth.user }}"]
+ initialDelaySeconds: 30
+ periodSeconds: 10
+ resources:
+ {{- toYaml .Values.postgres.resources | nindent 12 }}
+ volumeClaimTemplates:
+ - metadata:
+ name: data
+ spec:
+ accessModes: ["ReadWriteOnce"]
+ resources:
+ requests:
+ storage: {{ .Values.postgres.storage.size }}
+ {{- if .Values.postgres.storage.storageClass }}
+ storageClassName: "{{ .Values.postgres.storage.storageClass }}"
+ {{- end }}
diff --git a/deploy/charts/testgen-services/values.yaml b/deploy/charts/testgen-services/values.yaml
index af7ca7be..90c44114 100644
--- a/deploy/charts/testgen-services/values.yaml
+++ b/deploy/charts/testgen-services/values.yaml
@@ -2,13 +2,69 @@
# This is a YAML-formatted file.
# Declare variables to be passed into your templates.
-postgresql:
- fullnameOverride: postgresql
- auth:
- database: "datakitchen"
+nameOverride: ""
+fullnameOverride: "postgres"
+
+serviceAccount:
+ # Specifies whether a service account should be created
+ create: true
+ # Automatically mount a ServiceAccount's API credentials?
+ automount: true
+ # Annotations to add to the service account
+ annotations: {}
+ # The name of the service account to use.
+ # If not set and create is true, a name is generated using the fullname template
+ name: ""
+
+podAnnotations: {}
+podLabels: {}
+
+podSecurityContext: {}
+ # fsGroup: 2000
+
+securityContext: {}
+ # capabilities:
+ # drop:
+ # - ALL
+ # readOnlyRootFilesystem: true
+ # runAsNonRoot: true
+ # runAsUser: 1000
+
+secret:
+ create: true
+ name: services-secret
+
+postgres:
+ replicaCount: 1
+
+ service:
+ type: ClusterIP
+ port: 5432
+
image:
- repository: bitnamilegacy/postgresql
+ repository: postgres
+ pullPolicy: IfNotPresent
+ tag: "14.1-alpine"
+
+ imagePullSecrets: []
+
+ auth:
+ user: postgres
+ database: datakitchen
+ passwordKey: postgres-password
+
+ storage:
+ size: 1Gi
+ storageClass: ""
-global:
- security:
- allowInsecureImages: true
+ resources: {}
+ # We usually recommend not to specify default resources and to leave this as a conscious
+ # choice for the user. This also increases chances charts run on environments with little
+ # resources, such as Minikube. If you do want to specify resources, uncomment the following
+ # lines, adjust them as necessary, and remove the curly braces after 'resources:'.
+ # limits:
+ # cpu: 100m
+ # memory: 128Mi
+ # requests:
+ # cpu: 100m
+ # memory: 128Mi
diff --git a/deploy/install_linuxodbc.sh b/deploy/install_linuxodbc.sh
index 9f585221..e0f4080d 100755
--- a/deploy/install_linuxodbc.sh
+++ b/deploy/install_linuxodbc.sh
@@ -1,7 +1,8 @@
#!/usr/bin/env sh
# From: https://learn.microsoft.com/en-us/sql/connect/odbc/linux-mac/installing-the-microsoft-odbc-driver-for-sql-server
-# modifications: Added --non-interactive and --no-cache flags, removed sudo, added aarch64 as an alias for arm64
+# modifications: added --no-cache flag, removed sudo, added aarch64 as an alias for arm64,
+# added certificate installation, isolated the work into a temporary folder, replaced gpg --verify with gpgv
architecture="unsupported"
@@ -19,18 +20,30 @@ if [ "unsupported" = "$architecture" ]; then
exit 1
fi
-#Download the desired package(s)
-curl -O https://download.microsoft.com/download/7/6/d/76de322a-d860-4894-9945-f0cc5d6a45f8/msodbcsql18_18.4.1.1-1_$architecture.apk
-curl -O https://download.microsoft.com/download/7/6/d/76de322a-d860-4894-9945-f0cc5d6a45f8/mssql-tools18_18.4.1.1-1_$architecture.apk
-
-#(Optional) Verify signature, if 'gpg' is missing install it using 'apk add gnupg':
-curl -O https://download.microsoft.com/download/7/6/d/76de322a-d860-4894-9945-f0cc5d6a45f8/msodbcsql18_18.4.1.1-1_$architecture.sig
-curl -O https://download.microsoft.com/download/7/6/d/76de322a-d860-4894-9945-f0cc5d6a45f8/mssql-tools18_18.4.1.1-1_$architecture.sig
-
-curl https://packages.microsoft.com/keys/microsoft.asc | gpg --import -
-gpg --verify msodbcsql18_18.4.1.1-1_$architecture.sig msodbcsql18_18.4.1.1-1_$architecture.apk
-gpg --verify mssql-tools18_18.4.1.1-1_$architecture.sig mssql-tools18_18.4.1.1-1_$architecture.apk
-
-#Install the package(s)
-apk add --no-cache --non-interactive --allow-untrusted msodbcsql18_18.4.1.1-1_$architecture.apk
-apk add --no-cache --non-interactive --allow-untrusted mssql-tools18_18.4.1.1-1_$architecture.apk
+(
+ set -e
+ tmpdir="$(mktemp -d)"
+ trap 'rm -rf "$tmpdir"' EXIT
+ cd "$tmpdir"
+
+  # Recent Alpine versions lack the Microsoft certificate chain, so we download and install it manually
+ curl -fsSL -o cert.crt https://www.microsoft.com/pkiops/certs/Microsoft%20TLS%20G2%20ECC%20CA%20OCSP%2002.crt
+ openssl x509 -inform DER -in cert.crt -out /usr/local/share/ca-certificates/microsoft_tls_g2_ecc_ocsp_02.pem
+ update-ca-certificates
+
+ # Download the desired packages
+ curl -O https://download.microsoft.com/download/9dcab408-e0d4-4571-a81a-5a0951e3445f/msodbcsql18_18.6.1.1-1_$architecture.apk
+ curl -O https://download.microsoft.com/download/b60bb8b6-d398-4819-9950-2e30cf725fb0/mssql-tools18_18.6.1.1-1_$architecture.apk
+
+  # Verify signatures; if 'gpg' or 'gpgv' is missing, install them using 'apk add gpg gpgv':
+ curl -O https://download.microsoft.com/download/9dcab408-e0d4-4571-a81a-5a0951e3445f/msodbcsql18_18.6.1.1-1_$architecture.sig
+ curl -O https://download.microsoft.com/download/b60bb8b6-d398-4819-9950-2e30cf725fb0/mssql-tools18_18.6.1.1-1_$architecture.sig
+
+ curl https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor > microsoft.gpg
+ gpgv --keyring ./microsoft.gpg msodbcsql18_*.sig msodbcsql18_*.apk
+ gpgv --keyring ./microsoft.gpg mssql-tools18_*.sig mssql-tools18_*.apk
+
+ # Install the packages
+ apk add --no-cache --allow-untrusted msodbcsql18_18.6.1.1-1_$architecture.apk
+ apk add --no-cache --allow-untrusted mssql-tools18_18.6.1.1-1_$architecture.apk
+)
diff --git a/deploy/testgen-base.dockerfile b/deploy/testgen-base.dockerfile
index de45fcf7..d758a03f 100644
--- a/deploy/testgen-base.dockerfile
+++ b/deploy/testgen-base.dockerfile
@@ -1,4 +1,4 @@
-FROM python:3.12-alpine3.22
+FROM python:3.12-alpine3.23
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8
@@ -14,24 +14,22 @@ RUN apk update && apk upgrade && apk add --no-cache \
cmake \
musl-dev \
gfortran \
- linux-headers=6.14.2-r0 \
- # Tools needed for installing the MSSQL ODBC drivers \
+ linux-headers=6.16.12-r0 \
+ # Tools needed for installing the MSSQL ODBC drivers
curl \
gpg \
+ gpgv \
+ openssl \
# Additional libraries needed and their dev counterparts. We add both so that we can remove
# the *-dev later, keeping the libraries
- openblas=0.3.28-r0 \
- openblas-dev=0.3.28-r0 \
- unixodbc=2.3.12-r0 \
- unixodbc-dev=2.3.12-r0 \
+ openblas=0.3.30-r2 \
+ openblas-dev=0.3.30-r2 \
+ unixodbc=2.3.14-r0 \
+ unixodbc-dev=2.3.14-r0 \
+ libarrow=21.0.0-r4 \
+ apache-arrow-dev=21.0.0-r4 \
# Pinned versions for security
- xz=5.8.1-r0
-
-RUN apk add --no-cache \
- --repository https://dl-cdn.alpinelinux.org/alpine/v3.21/community \
- --repository https://dl-cdn.alpinelinux.org/alpine/v3.21/main \
- libarrow=18.1.0-r0 \
- apache-arrow-dev=18.1.0-r0
+ xz=5.8.2-r0
COPY --chmod=775 ./deploy/install_linuxodbc.sh /tmp/dk/install_linuxodbc.sh
RUN /tmp/dk/install_linuxodbc.sh
@@ -39,6 +37,10 @@ RUN /tmp/dk/install_linuxodbc.sh
# Install TestGen's main project empty pyproject.toml to install (and cache) the dependencies first
COPY ./pyproject.toml /tmp/dk/pyproject.toml
RUN mkdir /dk
+
+# Upgrading pip for security
+RUN python3 -m pip install --upgrade pip==26.0
+
RUN python3 -m pip install --prefix=/dk /tmp/dk
RUN apk del \
@@ -46,9 +48,12 @@ RUN apk del \
g++ \
make \
cmake \
+ curl \
musl-dev \
gfortran \
gpg \
+ gpgv \
+ openssl \
linux-headers \
openblas-dev \
unixodbc-dev \
diff --git a/deploy/testgen.dockerfile b/deploy/testgen.dockerfile
index 4ff2ff94..5c4bb933 100644
--- a/deploy/testgen.dockerfile
+++ b/deploy/testgen.dockerfile
@@ -1,4 +1,4 @@
-ARG TESTGEN_BASE_LABEL=v9
+ARG TESTGEN_BASE_LABEL=v11
FROM datakitchen/dataops-testgen-base:${TESTGEN_BASE_LABEL} AS release-image
@@ -17,6 +17,8 @@ COPY . /tmp/dk/
RUN python3 -m pip install --prefix=/dk /tmp/dk
RUN rm -Rf /tmp/dk
+RUN tg-patch-streamlit
+
RUN addgroup -S testgen && adduser -S testgen -G testgen
# Streamlit has to be able to write to these dirs
diff --git a/pyproject.toml b/pyproject.toml
index 7267ca28..57285fc2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "dataops-testgen"
-version = "4.39.2"
+version = "5.0.1"
description = "DataKitchen's Data Quality DataOps TestGen"
authors = [
{ "name" = "DataKitchen, Inc.", "email" = "info@datakitchen.io" },
@@ -41,7 +41,7 @@ dependencies = [
"requests_extensions==1.1.3",
"numpy==1.26.4",
"pandas==2.1.4",
- "streamlit==1.46.1",
+ "streamlit==1.53.0",
"streamlit-extras==0.3.0",
"streamlit-aggrid==0.3.4.post3",
"plotly_express==0.4.1",
@@ -62,9 +62,11 @@ dependencies = [
"cron-descriptor==2.0.5",
"pybars3==0.9.7",
"azure-identity==1.25.1",
+ "statsmodels==0.14.6",
+ "holidays~=0.89",
# Pinned to match the manually compiled libs or for security
- "pyarrow==18.1.0",
+ "pyarrow==21.0.0",
"matplotlib==3.9.2",
"scipy==1.14.1",
"jinja2==3.1.6",
@@ -92,12 +94,13 @@ release = [
[project.entry-points.console_scripts]
testgen = "testgen.__main__:cli"
+tg-patch-streamlit = "testgen.ui.scripts.patch_streamlit:patch"
[project.urls]
"Source Code" = "https://github.com/DataKitchen/dataops-testgen"
"Bug Tracker" = "https://github.com/DataKitchen/dataops-testgen/issues"
-"Documentation" = "https://docs.datakitchen.io/articles/#!dataops-testgen-help/dataops-testgen-help"
-"Release Notes" = "https://docs.datakitchen.io/articles/#!dataops-testgen-help/testgen-release-notes"
+"Documentation" = "https://docs.datakitchen.io/articles/dataops-testgen-help/dataops-testgen-help"
+"Release Notes" = "https://docs.datakitchen.io/articles/dataops-testgen-help/testgen-release-notes"
"Slack" = "https://data-observability-slack.datakitchen.io/join"
"Homepage" = "https://example.com"
@@ -107,6 +110,7 @@ include-package-data = true
[tool.setuptools.package-data]
"*" = ["*.toml", "*.sql", "*.yaml"]
"testgen.template" = ["*.sql", "*.yaml", "**/*.sql", "**/*.yaml"]
+"testgen.ui.static" = ["**/*.js", "**/*.css", "**/*.woff2"]
"testgen.ui.assets" = ["*.svg", "*.png", "*.js", "*.css", "*.ico", "flavors/*.svg"]
"testgen.ui.components.frontend" = ["*.html", "**/*.js", "**/*.css", "**/*.woff2", "**/*.svg"]
@@ -274,3 +278,19 @@ push = false
"pyproject.toml" = [
'version = "{version}"',
]
+
+[[tool.streamlit.component.components]]
+name = "table_group_wizard"
+asset_dir = "ui/components/frontend/js"
+
+[[tool.streamlit.component.components]]
+name = "edit_monitor_settings"
+asset_dir = "ui/components/frontend/js"
+
+[[tool.streamlit.component.components]]
+name = "table_monitoring_trends"
+asset_dir = "ui/components/frontend/js"
+
+[[tool.streamlit.component.components]]
+name = "edit_table_monitors"
+asset_dir = "ui/components/frontend/js"
diff --git a/setup.py b/setup.py
new file mode 100644
index 00000000..16bcbedc
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,22 @@
+import os
+
+from setuptools import setup
+from setuptools.command.build_py import build_py
+
+THIS_DIR = os.path.dirname(os.path.abspath(__file__))
+ROOT_TOML = os.path.abspath(os.path.join(THIS_DIR, "pyproject.toml"))
+
+class CustomBuildPy(build_py):
+ def run(self):
+ super().run()
+ target_toml = os.path.join(self.build_lib, "testgen", "pyproject.toml")
+ if os.path.exists(ROOT_TOML):
+ os.makedirs(os.path.dirname(target_toml), exist_ok=True)
+ self.copy_file(ROOT_TOML, target_toml)
+
+
+setup(
+ cmdclass={
+ "build_py": CustomBuildPy,
+ },
+)
diff --git a/testgen/__main__.py b/testgen/__main__.py
index 8463ab4f..62ae21c0 100644
--- a/testgen/__main__.py
+++ b/testgen/__main__.py
@@ -10,7 +10,6 @@
from click.core import Context
from testgen import settings
-from testgen.commands.run_generate_tests import run_test_gen_queries
from testgen.commands.run_get_entities import (
run_get_results,
run_get_test_suite,
@@ -29,10 +28,11 @@
from testgen.commands.run_launch_db_config import run_launch_db_config
from testgen.commands.run_observability_exporter import run_observability_exporter
from testgen.commands.run_profiling import run_profiling
-from testgen.commands.run_quick_start import run_quick_start, run_quick_start_increment
+from testgen.commands.run_quick_start import run_monitor_increment, run_quick_start, run_quick_start_increment
from testgen.commands.run_test_execution import run_test_execution
from testgen.commands.run_test_metadata_exporter import run_test_metadata_exporter
from testgen.commands.run_upgrade_db_config import get_schema_revision, is_db_revision_up_to_date, run_upgrade_db_config
+from testgen.commands.test_generation import run_monitor_generation, run_test_generation
from testgen.common import (
configure_logging,
display_service,
@@ -57,7 +57,7 @@
APP_MODULES = ["ui", "scheduler"]
VERSION_DATA = version_service.get_version()
-
+CHILDREN_POLL_INTERVAL = 10
@dataclass
class Configuration:
@@ -133,18 +133,25 @@ def run_profile(table_group_id: str):
@cli.command("run-test-generation", help="Generates or refreshes the tests for a table group.")
+@click.option(
+ "-t",
+ "--test-suite-id",
+ required=False,
+ type=click.STRING,
+ help="ID of the test suite to generate. Use a test_suite_id shown in list-test-suites.",
+)
@click.option(
"-tg",
"--table-group-id",
help="The identifier for the table group used during a profile run. Use a table_group_id shown in list-table-groups.",
- required=True,
+ required=False,
type=click.STRING,
)
@click.option(
"-ts",
"--test-suite-key",
help="The identifier for a test suite. Use a test_suite_key shown in list-test-suites.",
- required=True,
+ required=False,
type=click.STRING,
)
@click.option(
@@ -152,16 +159,37 @@ def run_profile(table_group_id: str):
"--generation-set",
help="A defined subset of tests to generate for your purpose. Use a generation_set defined for your project.",
required=False,
- default=None,
+ default="Standard",
)
-@pass_configuration
-def run_test_generation(configuration: Configuration, table_group_id: str, test_suite_key: str, generation_set: str):
- LOG.info("CurrentStep: Generate Tests - Main Procedure")
- message = run_test_gen_queries(table_group_id, test_suite_key, generation_set)
- LOG.info("Current Step: Generate Tests - Main Procedure Complete")
+@with_database_session
+def run_generation(test_suite_id: str | None = None, table_group_id: str | None = None, test_suite_key: str | None = None, generation_set: str | None = None):
+ click.echo(f"run-test-generation for suite: {test_suite_id or test_suite_key}")
+ # For backward compatibility
+ if not test_suite_id:
+ test_suites = TestSuite.select_minimal_where(
+ TestSuite.table_groups_id == table_group_id,
+ TestSuite.test_suite == test_suite_key,
+ )
+ if test_suites:
+ test_suite_id = test_suites[0].id
+ message = run_test_generation(test_suite_id, generation_set)
click.echo("\n" + message)
+@cli.command("run-monitor-generation", help="Generates or refreshes the monitors for a table group.")
+@click.option(
+ "-t",
+ "--test-suite-id",
+ required=True,
+ type=click.STRING,
+ help="ID of the monitor suite to generate",
+)
+@with_database_session
+def generate_monitors(test_suite_id: str):
+ click.echo(f"run-monitor-generation for suite: {test_suite_id}")
+ run_monitor_generation(test_suite_id, ["Freshness_Trend", "Volume_Trend", "Schema_Drift"])
+
+
@register_scheduler_job
@cli.command("run-tests", help="Performs tests defined for a test suite.")
@click.option(
@@ -201,6 +229,22 @@ def run_tests(test_suite_id: str | None = None, project_key: str | None = None,
click.echo("\n" + message)
+@register_scheduler_job
+@cli.command("run-monitors", help="Performs tests defined for a monitor suite.")
+@click.option(
+ "-t",
+ "--test-suite-id",
+ required=True,
+ type=click.STRING,
+ help="ID of the monitor suite to run.",
+)
+@with_database_session
+def run_monitors(test_suite_id: str):
+ click.echo(f"run-monitors for suite: {test_suite_id}")
+ message = run_test_execution(test_suite_id)
+ click.echo("\n" + message)
+
+
@cli.command("list-profiles", help="Lists all profile runs for a table group.")
@click.option(
"-tg",
@@ -384,7 +428,7 @@ def quick_start(
click.echo("loading initial data")
run_quick_start_increment(0)
now_date = datetime.now(UTC)
- time_delta = timedelta(days=-30) # 1 month ago
+ time_delta = timedelta(days=-35) # before the first monitor iteration (~34 days back)
table_group_id = "0ea85e17-acbe-47fe-8394-9970725ad37d"
test_suite_id = "9df7489d-92b3-49f9-95ca-512160d7896f"
@@ -392,8 +436,8 @@ def quick_start(
message = run_profiling(table_group_id, run_date=now_date + time_delta)
click.echo("\n" + message)
- LOG.info(f"run-test-generation with table_group_id: {table_group_id} test_suite: {settings.DEFAULT_TEST_SUITE_KEY}")
- message = run_test_gen_queries(table_group_id, settings.DEFAULT_TEST_SUITE_KEY)
+ LOG.info(f"run-test-generation with test_suite_id: {test_suite_id}")
+ message = with_database_session(run_test_generation)(test_suite_id, "Standard")
click.echo("\n" + message)
run_test_execution(test_suite_id, run_date=now_date + time_delta)
@@ -405,6 +449,22 @@ def quick_start(
run_quick_start_increment(iteration)
run_test_execution(test_suite_id, run_date=run_date)
+ monitor_iterations = 68 # ~5 weeks
+ monitor_interval = timedelta(hours=12)
+ monitor_test_suite_id = "823a1fef-9b6d-48d5-9d0f-2db9812cc318"
+ # Round down to nearest 12-hour mark (12:00 AM or 12:00 PM UTC)
+ now = datetime.now(UTC)
+ nearest_12h_mark = now.replace(hour=12 if now.hour >= 12 else 0, minute=0, second=0, microsecond=0)
+ monitor_run_date = nearest_12h_mark - monitor_interval * (monitor_iterations - 1)
+ weekday_morning_count = 0
+ for iteration in range(1, monitor_iterations + 1):
+ click.echo(f"Running monitor iteration: {iteration} / {monitor_iterations}")
+ if monitor_run_date.weekday() < 5 and monitor_run_date.hour < 12:
+ weekday_morning_count += 1
+ run_monitor_increment(monitor_run_date, iteration, weekday_morning_count)
+ run_test_execution(monitor_test_suite_id, run_date=monitor_run_date)
+ monitor_run_date += monitor_interval
+
click.echo("Quick start has successfully finished.")
@@ -646,7 +706,8 @@ def run_ui():
use_ssl = os.path.isfile(settings.SSL_CERT_FILE) and os.path.isfile(settings.SSL_KEY_FILE)
- patch_streamlit.patch(force=True)
+ if settings.IS_DEBUG:
+ patch_streamlit.patch(dev=True)
@with_database_session
def init_ui():
@@ -671,6 +732,7 @@ def init_ui():
"--browser.gatherUsageStats=false",
"--client.showErrorDetails=none",
"--client.toolbarMode=minimal",
+ "--server.enableStaticServing=true",
f"--server.sslCertFile={settings.SSL_CERT_FILE}" if use_ssl else "",
f"--server.sslKeyFile={settings.SSL_KEY_FILE}" if use_ssl else "",
"--",
@@ -715,8 +777,20 @@ def term_children(signum, _):
signal.signal(signal.SIGINT, term_children)
signal.signal(signal.SIGTERM, term_children)
- for child in children:
- child.wait()
+ terminating = False
+ while children:
+ try:
+ children[0].wait(CHILDREN_POLL_INTERVAL)
+ except subprocess.TimeoutExpired:
+ pass
+
+ for child in children:
+ if child.poll() is not None:
+ children.remove(child)
+ if not terminating:
+ terminating = True
+ term_children(signal.SIGTERM, None)
+
if __name__ == "__main__":
diff --git a/testgen/commands/queries/execute_tests_query.py b/testgen/commands/queries/execute_tests_query.py
index d71572b6..90d67859 100644
--- a/testgen/commands/queries/execute_tests_query.py
+++ b/testgen/commands/queries/execute_tests_query.py
@@ -1,17 +1,28 @@
import dataclasses
from collections.abc import Iterable
-from datetime import datetime
+from datetime import date, datetime
from typing import TypedDict
from uuid import UUID
+import pandas as pd
+
from testgen.common import read_template_sql_file
from testgen.common.clean_sql import concat_columns
from testgen.common.database.database_service import get_flavor_service, get_tg_schema, replace_params
+from testgen.common.freshness_service import (
+ count_excluded_minutes,
+ get_schedule_params,
+ is_excluded_day,
+ resolve_holiday_dates,
+)
from testgen.common.models.connection import Connection
+from testgen.common.models.scheduler import JobSchedule
from testgen.common.models.table_group import TableGroup
from testgen.common.models.test_definition import TestRunType, TestScope
from testgen.common.models.test_run import TestRun
+from testgen.common.models.test_suite import TestSuite
from testgen.common.read_file import replace_templated_functions
+from testgen.utils import to_sql_timestamp
@dataclasses.dataclass
@@ -46,10 +57,12 @@ class TestExecutionDef(InputParameters):
table_name: str
column_name: str
skip_errors: int
+ history_calculation: str
custom_query: str
+ prediction: dict | str | None
run_type: TestRunType
test_scope: TestScope
- template_name: str
+ template: str
measure: str
test_operator: str
test_condition: str
@@ -86,14 +99,27 @@ class TestExecutionSQL:
"result_measure",
)
- def __init__(self, connection: Connection, table_group: TableGroup, test_run: TestRun):
+ def __init__(self, connection: Connection, table_group: TableGroup, test_suite: TestSuite, test_run: TestRun):
self.connection = connection
self.table_group = table_group
+ self.test_suite = test_suite
self.test_run = test_run
- self.run_date = test_run.test_starttime.strftime("%Y-%m-%d %H:%M:%S")
+ self.run_date = test_run.test_starttime
self.flavor = connection.sql_flavor
self.flavor_service = get_flavor_service(self.flavor)
+ self._exclude_weekends = bool(self.test_suite.predict_exclude_weekends)
+ self._holiday_dates: set[date] | None = None
+ self._schedule_tz: str | None = None
+ if test_suite.is_monitor:
+ schedule = JobSchedule.get(JobSchedule.kwargs["test_suite_id"].astext == str(test_suite.id))
+ self._schedule_tz = schedule.cron_tz or "UTC" if schedule else None
+ if test_suite.holiday_codes_list:
+ self._holiday_dates = resolve_holiday_dates(
+ test_suite.holiday_codes_list,
+ pd.DatetimeIndex([datetime(self.run_date.year - 1, 1, 1), datetime(self.run_date.year + 1, 12, 31)]),
+ )
+
def _get_input_parameters(self, test_def: TestExecutionDef) -> str:
return "; ".join(
f"{field.name}={getattr(test_def, field.name)}"
@@ -104,9 +130,10 @@ def _get_input_parameters(self, test_def: TestExecutionDef) -> str:
def _get_params(self, test_def: TestExecutionDef | None = None) -> dict:
quote = self.flavor_service.quote_character
params = {
+ "TABLE_GROUPS_ID": self.table_group.id,
"TEST_SUITE_ID": self.test_run.test_suite_id,
"TEST_RUN_ID": self.test_run.id,
- "RUN_DATE": self.run_date,
+ "RUN_DATE": to_sql_timestamp(self.run_date),
"SQL_FLAVOR": self.flavor,
"VARCHAR_TYPE": self.flavor_service.varchar_type,
"QUOTE": quote,
@@ -118,22 +145,24 @@ def _get_params(self, test_def: TestExecutionDef | None = None) -> dict:
"TEST_DEFINITION_ID": test_def.id,
"APP_SCHEMA_NAME": get_tg_schema(),
"SCHEMA_NAME": test_def.schema_name,
- "TABLE_GROUPS_ID": self.table_group.id,
"TABLE_NAME": test_def.table_name,
"COLUMN_NAME": f"{quote}{test_def.column_name or ''}{quote}",
"COLUMN_NAME_NO_QUOTES": test_def.column_name,
"CONCAT_COLUMNS": concat_columns(test_def.column_name, self.null_value) if test_def.column_name else "",
"SKIP_ERRORS": test_def.skip_errors or 0,
+ "CUSTOM_QUERY": test_def.custom_query,
"BASELINE_CT": test_def.baseline_ct,
"BASELINE_UNIQUE_CT": test_def.baseline_unique_ct,
"BASELINE_VALUE": test_def.baseline_value,
"BASELINE_VALUE_CT": test_def.baseline_value_ct,
- "THRESHOLD_VALUE": test_def.threshold_value,
+ "THRESHOLD_VALUE": test_def.threshold_value or 0,
"BASELINE_SUM": test_def.baseline_sum,
"BASELINE_AVG": test_def.baseline_avg,
"BASELINE_SD": test_def.baseline_sd,
- "LOWER_TOLERANCE": test_def.lower_tolerance,
- "UPPER_TOLERANCE": test_def.upper_tolerance,
+ "LOWER_TOLERANCE": "NULL" if test_def.lower_tolerance in (None, "") else test_def.lower_tolerance,
+ "UPPER_TOLERANCE": "NULL" if test_def.upper_tolerance in (None, "") else test_def.upper_tolerance,
+ # SUBSET_CONDITION should be replaced after CUSTOM_QUERY
+ # since the latter may contain the former
"SUBSET_CONDITION": test_def.subset_condition or "1=1",
"GROUPBY_NAMES": test_def.groupby_names,
"HAVING_CONDITION": f"HAVING {test_def.having_condition}" if test_def.having_condition else "",
@@ -146,10 +175,35 @@ def _get_params(self, test_def: TestExecutionDef | None = None) -> dict:
"MATCH_GROUPBY_NAMES": test_def.match_groupby_names,
"CONCAT_MATCH_GROUPBY": concat_columns(test_def.match_groupby_names, self.null_value) if test_def.match_groupby_names else "",
"MATCH_HAVING_CONDITION": f"HAVING {test_def.match_having_condition}" if test_def.match_having_condition else "",
- "CUSTOM_QUERY": test_def.custom_query,
"COLUMN_TYPE": test_def.column_type,
"INPUT_PARAMETERS": self._get_input_parameters(test_def),
})
+
+ # Freshness exclusion params — computed per test at execution time
+ if test_def.test_type == "Freshness_Trend" and test_def.baseline_sum:
+ sched = get_schedule_params(test_def.prediction)
+ has_exclusions = self._exclude_weekends or sched.excluded_days or sched.window_start is not None
+ if has_exclusions:
+ last_update = pd.Timestamp(test_def.baseline_sum)
+ excluded = int(count_excluded_minutes(
+ last_update, self.run_date, self._exclude_weekends, self._holiday_dates,
+ tz=self._schedule_tz, excluded_days=sched.excluded_days,
+ window_start=sched.window_start, window_end=sched.window_end,
+ ))
+ is_excl = 1 if is_excluded_day(
+ pd.Timestamp(self.run_date), self._exclude_weekends, self._holiday_dates,
+ tz=self._schedule_tz, excluded_days=sched.excluded_days,
+ window_start=sched.window_start, window_end=sched.window_end,
+ ) else 0
+ params["EXCLUDED_MINUTES"] = excluded
+ params["IS_EXCLUDED_DAY"] = is_excl
+ else:
+ params["EXCLUDED_MINUTES"] = 0
+ params["IS_EXCLUDED_DAY"] = 0
+ else:
+ params["EXCLUDED_MINUTES"] = 0
+ params["IS_EXCLUDED_DAY"] = 0
+
return params
def _get_query(
@@ -171,6 +225,14 @@ def _get_query(
query = query.replace(":", "\\:")
return query, None if no_bind else params
+
+ def has_schema_changes(self) -> tuple[dict]:
+ # Runs on App database
+ return self._get_query("has_schema_changes.sql")
+
+ def get_errored_autogen_monitors(self) -> tuple[str, dict]:
+ # Runs on App database
+ return self._get_query("get_errored_autogen_monitors.sql")
def get_active_test_definitions(self) -> tuple[dict]:
# Runs on App database
@@ -212,21 +274,28 @@ def disable_invalid_test_definitions(self) -> tuple[str, dict]:
# Runs on App database
return self._get_query("disable_invalid_test_definitions.sql")
- def update_historic_thresholds(self) -> tuple[str, dict]:
+ def update_history_calc_thresholds(self) -> tuple[str, dict]:
# Runs on App database
- return self._get_query("update_historic_thresholds.sql")
+ return self._get_query("update_history_calc_thresholds.sql")
def run_query_test(self, test_def: TestExecutionDef) -> tuple[str, dict]:
# Runs on Target database
- folder = "generic" if test_def.template_name.endswith("_generic.sql") else self.flavor
- return self._get_query(
- test_def.template_name,
- f"flavors/{folder}/exec_query_tests",
- no_bind=True,
- # Final replace in CUSTOM_QUERY
- extra_params={"DATA_SCHEMA": test_def.schema_name},
- test_def=test_def,
- )
+ if test_def.template.startswith("@"):
+ folder = "generic" if test_def.template.endswith("_generic.sql") else self.flavor
+ return self._get_query(
+ test_def.template,
+ f"flavors/{folder}/exec_query_tests",
+ no_bind=True,
+ # Final replace in CUSTOM_QUERY
+ extra_params={"DATA_SCHEMA": test_def.schema_name},
+ test_def=test_def,
+ )
+ else:
+ query = test_def.template
+ params = self._get_params(test_def)
+ params.update({"DATA_SCHEMA": test_def.schema_name})
+ query = replace_params(query, params)
+ return query, params
def aggregate_cat_tests(
self,
@@ -246,9 +315,18 @@ def aggregate_cat_tests(
measure = replace_templated_functions(measure, self.flavor)
td.measure_expression = f"COALESCE(CAST({measure} AS {varchar_type}) {concat_operator} '|', '{self.null_value}|')"
- condition = replace_params(f"{td.measure}{td.test_operator}{td.test_condition}", params)
- condition = replace_templated_functions(condition, self.flavor)
- td.condition_expression = f"CASE WHEN {condition} THEN '0,' ELSE '1,' END"
+ # For prediction mode, return -1 during training period
+ if td.history_calculation == "PREDICT" and (td.lower_tolerance in (None, "") or td.upper_tolerance in (None, "")):
+ td.condition_expression = "'-1,'"
+ else:
+ condition = (
+ f"{td.measure} {td.test_operator} {td.test_condition}"
+ if "BETWEEN" in td.test_operator
+ else f"{td.measure}{td.test_operator}{td.test_condition}"
+ )
+ condition = replace_params(condition, params)
+ condition = replace_templated_functions(condition, self.flavor)
+ td.condition_expression = f"CASE WHEN {condition} THEN '0,' ELSE '1,' END"
aggregate_queries: list[tuple[str, None]] = []
aggregate_test_defs: list[list[TestExecutionDef]] = []
diff --git a/testgen/commands/queries/generate_tests_query.py b/testgen/commands/queries/generate_tests_query.py
deleted file mode 100644
index cece2d3e..00000000
--- a/testgen/commands/queries/generate_tests_query.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import logging
-from datetime import UTC, datetime
-from typing import ClassVar, TypedDict
-
-from testgen.common import CleanSQL, read_template_sql_file
-from testgen.common.database.database_service import get_flavor_service, replace_params
-from testgen.common.read_file import get_template_files
-
-LOG = logging.getLogger("testgen")
-
-class GenTestParams(TypedDict):
- test_type: str
- selection_criteria: str
- default_parm_columns: str
- default_parm_values: str
-
-
-class CDeriveTestsSQL:
- run_date = ""
- project_code = ""
- connection_id = ""
- table_groups_id = ""
- data_schema = ""
- test_suite = ""
- test_suite_id = ""
- generation_set = ""
- as_of_date = ""
- sql_flavor = ""
- gen_test_params: ClassVar[GenTestParams] = {}
-
- _use_clean = False
-
- def __init__(self, flavor):
- self.sql_flavor = flavor
- self.flavor_service = get_flavor_service(flavor)
-
- today = datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S")
- self.run_date = today
- self.as_of_date = today
-
- def _get_params(self) -> dict:
- return {
- **{key.upper(): value for key, value in self.gen_test_params.items()},
- "PROJECT_CODE": self.project_code,
- "SQL_FLAVOR": self.sql_flavor,
- "CONNECTION_ID": self.connection_id,
- "TABLE_GROUPS_ID": self.table_groups_id,
- "RUN_DATE": self.run_date,
- "TEST_SUITE": self.test_suite,
- "TEST_SUITE_ID": self.test_suite_id,
- "GENERATION_SET": self.generation_set,
- "AS_OF_DATE": self.as_of_date,
- "DATA_SCHEMA": self.data_schema,
- "QUOTE": self.flavor_service.quote_character,
- }
-
- def _get_query(self, template_file_name: str, sub_directory: str | None = "generation") -> tuple[str, dict]:
- query = read_template_sql_file(template_file_name, sub_directory)
- params = self._get_params()
- query = replace_params(query, params)
- if self._use_clean:
- query = CleanSQL(query)
- return query, params
-
- def GetInsertTestSuiteSQL(self) -> tuple[str, dict]:
- # Runs on App database
- return self._get_query("gen_insert_test_suite.sql")
-
- def GetTestTypesSQL(self) -> tuple[str, dict]:
- # Runs on App database
- return self._get_query("gen_standard_test_type_list.sql")
-
- def GetTestDerivationQueriesAsList(self, template_directory: str) -> list[tuple[str, dict]]:
- # Runs on App database
- generic_template_directory = template_directory
- flavor_template_directory = f"flavors.{self.sql_flavor}.{template_directory}"
-
- query_templates = {}
- try:
- for query_file in get_template_files(r"^.*sql$", generic_template_directory):
- query_templates[query_file.name] = generic_template_directory
- except:
- LOG.debug(
- f"query template '{generic_template_directory}' directory does not exist",
- exc_info=True,
- stack_info=True,
- )
-
- try:
- for query_file in get_template_files(r"^.*sql$", flavor_template_directory):
- query_templates[query_file.name] = flavor_template_directory
- except:
- LOG.debug(
- f"query template '{generic_template_directory}' directory does not exist",
- exc_info=True,
- stack_info=True,
- )
-
- queries = []
- for filename, sub_directory in query_templates.items():
- queries.append(self._get_query(filename, sub_directory=sub_directory))
-
- return queries
-
- def GetTestQueriesFromGenericFile(self) -> tuple[str, dict]:
- # Runs on App database
- return self._get_query("gen_standard_tests.sql")
-
- def GetDeleteOldTestsQuery(self) -> tuple[str, dict]:
- # Runs on App database
- return self._get_query("gen_delete_old_tests.sql")
diff --git a/testgen/commands/queries/profiling_query.py b/testgen/commands/queries/profiling_query.py
index a5fd7ba0..c1ec78fe 100644
--- a/testgen/commands/queries/profiling_query.py
+++ b/testgen/commands/queries/profiling_query.py
@@ -1,14 +1,14 @@
import dataclasses
-import re
from uuid import UUID
from testgen.commands.queries.refresh_data_chars_query import ColumnChars
from testgen.common import read_template_sql_file, read_template_yaml_file
-from testgen.common.database.database_service import replace_params
+from testgen.common.database.database_service import process_conditionals, replace_params
from testgen.common.models.connection import Connection
from testgen.common.models.profiling_run import ProfilingRun
from testgen.common.models.table_group import TableGroup
from testgen.common.read_file import replace_templated_functions
+from testgen.utils import to_sql_timestamp
@dataclasses.dataclass
@@ -58,7 +58,7 @@ def __init__(self, connection: Connection, table_group: TableGroup, profiling_ru
self.connection = connection
self.table_group = table_group
self.profiling_run = profiling_run
- self.run_date = profiling_run.profiling_starttime.strftime("%Y-%m-%d %H:%M:%S")
+ self.run_date = profiling_run.profiling_starttime
self.flavor = connection.sql_flavor
self._profiling_template: dict = None
@@ -68,7 +68,7 @@ def _get_params(self, column_chars: ColumnChars | None = None, table_sampling: T
"CONNECTION_ID": self.connection.connection_id,
"TABLE_GROUPS_ID": self.table_group.id,
"PROFILE_RUN_ID": self.profiling_run.id,
- "RUN_DATE": self.run_date,
+ "RUN_DATE": to_sql_timestamp(self.run_date),
"SQL_FLAVOR": self.flavor,
"DATA_SCHEMA": self.table_group.table_group_schema,
"PROFILE_ID_COLUMN_MASK": self.table_group.profile_id_column_mask,
@@ -106,7 +106,7 @@ def _get_query(
params = {}
if query:
- query = self._process_conditionals(query, extra_params)
+ query = process_conditionals(query, extra_params)
params.update(self._get_params(column_chars, table_sampling))
if extra_params:
params.update(extra_params)
@@ -116,32 +116,6 @@ def _get_query(
return query, params
- def _process_conditionals(self, query: str, extra_params: dict | None = None) -> str:
- re_pattern = re.compile(r"^--\s+TG-(IF|ELSE|ENDIF)(?:\s+(\w+))?\s*$")
- condition = None
- updated_query = []
- for line in query.splitlines(True):
- if re_match := re_pattern.match(line):
- match re_match.group(1):
- case "IF" if condition is None and (variable := re_match.group(2)) is not None:
- result = extra_params.get(variable)
- if result is None:
- result = getattr(self, variable, None)
- condition = bool(result)
- case "ELSE" if condition is not None:
- condition = not condition
- case "ENDIF" if condition is not None:
- condition = None
- case _:
- raise ValueError("Template conditional misused")
- elif condition is not False:
- updated_query.append(line)
-
- if condition is not None:
- raise ValueError("Template conditional misused")
-
- return "".join(updated_query)
-
def _get_profiling_template(self) -> dict:
if not self._profiling_template:
self._profiling_template = read_template_yaml_file(
diff --git a/testgen/commands/queries/refresh_data_chars_query.py b/testgen/commands/queries/refresh_data_chars_query.py
index 9ef02506..1df6e994 100644
--- a/testgen/commands/queries/refresh_data_chars_query.py
+++ b/testgen/commands/queries/refresh_data_chars_query.py
@@ -6,7 +6,7 @@
from testgen.common.database.database_service import get_flavor_service, replace_params
from testgen.common.models.connection import Connection
from testgen.common.models.table_group import TableGroup
-from testgen.utils import chunk_queries
+from testgen.utils import chunk_queries, to_sql_timestamp
@dataclasses.dataclass
@@ -20,6 +20,8 @@ class ColumnChars:
db_data_type: str = None
is_decimal: bool = False
approx_record_ct: int = None
+ # This should not default to 0 since we don't always retrieve actual row counts
+ # UI relies on the null value to know that the approx_record_ct should be displayed instead
record_ct: int = None
@@ -132,7 +134,7 @@ def get_staging_data_chars(self, data_chars: list[ColumnChars], run_date: dateti
return [
[
self.table_group.id,
- run_date,
+ to_sql_timestamp(run_date),
column.schema_name,
column.table_name,
column.column_name,
@@ -146,9 +148,9 @@ def get_staging_data_chars(self, data_chars: list[ColumnChars], run_date: dateti
for column in data_chars
]
- def update_data_chars(self, run_date: str) -> list[tuple[str, dict]]:
+ def update_data_chars(self, run_date: datetime) -> list[tuple[str, dict]]:
# Runs on App database
- params = {"RUN_DATE": run_date}
+ params = {"RUN_DATE": to_sql_timestamp(run_date)}
return [
self._get_query("data_chars_update.sql", extra_params=params),
self._get_query("data_chars_staging_delete.sql", extra_params=params),
diff --git a/testgen/commands/run_generate_tests.py b/testgen/commands/run_generate_tests.py
deleted file mode 100644
index 71b48491..00000000
--- a/testgen/commands/run_generate_tests.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import logging
-
-from testgen import settings
-from testgen.commands.queries.generate_tests_query import CDeriveTestsSQL
-from testgen.common import execute_db_queries, fetch_dict_from_db, get_test_generation_params, set_target_db_params
-from testgen.common.mixpanel_service import MixpanelService
-from testgen.common.models import with_database_session
-from testgen.common.models.connection import Connection
-
-LOG = logging.getLogger("testgen")
-
-
-@with_database_session
-def run_test_gen_queries(table_group_id: str, test_suite: str, generation_set: str | None = None):
- if table_group_id is None:
- raise ValueError("Table Group ID was not specified")
-
- LOG.info("CurrentStep: Assigning Connection Parameters")
- connection = Connection.get_by_table_group(table_group_id)
- set_target_db_params(connection.__dict__)
-
- clsTests = CDeriveTestsSQL(connection.sql_flavor)
-
- LOG.info(f"CurrentStep: Retrieving General Parameters for Test Suite {test_suite}")
- params = get_test_generation_params(table_group_id, test_suite)
-
-
- # Set static parms
- clsTests.project_code = params["project_code"]
- clsTests.test_suite = test_suite
- clsTests.generation_set = generation_set if generation_set is not None else ""
- clsTests.test_suite_id = params["test_suite_id"] if params["test_suite_id"] else ""
- clsTests.connection_id = str(connection.connection_id)
- clsTests.table_groups_id = table_group_id
- clsTests.data_schema = params["table_group_schema"]
- if params["profiling_as_of_date"] is not None:
- clsTests.as_of_date = params["profiling_as_of_date"].strftime("%Y-%m-%d %H:%M:%S")
-
- if params["test_suite_id"]:
- clsTests.test_suite_id = params["test_suite_id"]
- else:
- LOG.info("CurrentStep: Creating new Test Suite")
- insert_ids, _ = execute_db_queries([clsTests.GetInsertTestSuiteSQL()])
- clsTests.test_suite_id = insert_ids[0]
-
- LOG.info("CurrentStep: Compiling Test Gen Queries")
-
- lstFunnyTemplateQueries = clsTests.GetTestDerivationQueriesAsList("gen_funny_cat_tests")
- lstQueryTemplateQueries = clsTests.GetTestDerivationQueriesAsList("gen_query_tests")
- lstGenericTemplateQueries = []
-
- # Delete old Tests
- deleteQuery = clsTests.GetDeleteOldTestsQuery()
-
- # Retrieve test_types as parms from list of dictionaries: test_type, selection_criteria, default_parm_columns,
- # default_parm_values
- lstTestTypes = fetch_dict_from_db(*clsTests.GetTestTypesSQL())
-
- if lstTestTypes is None:
- raise ValueError("Test Type Parameters not found")
- elif (
- lstTestTypes[0]["test_type"] == ""
- or lstTestTypes[0]["selection_criteria"] == ""
- or lstTestTypes[0]["default_parm_columns"] == ""
- or lstTestTypes[0]["default_parm_values"] == ""
- ):
- raise ValueError("Test Type parameters not correctly set")
-
- lstGenericTemplateQueries = []
- for dctTestParms in lstTestTypes:
- clsTests.gen_test_params = dctTestParms
- lstGenericTemplateQueries.append(clsTests.GetTestQueriesFromGenericFile())
-
- LOG.info("TestGen CAT Queries were compiled")
-
- # Make sure delete, then generic templates run before the funny templates
- lstQueries = [deleteQuery, *lstGenericTemplateQueries, *lstFunnyTemplateQueries, *lstQueryTemplateQueries]
-
- if lstQueries:
- LOG.info("Running Test Generation Template Queries")
- execute_db_queries(lstQueries)
- message = "Test generation completed successfully."
- else:
- message = "No TestGen Queries were compiled."
-
- MixpanelService().send_event(
- "generate-tests",
- source=settings.ANALYTICS_JOB_SOURCE,
- sql_flavor=clsTests.sql_flavor,
- generation_set=clsTests.generation_set,
- )
-
- return message
diff --git a/testgen/commands/run_launch_db_config.py b/testgen/commands/run_launch_db_config.py
index 83da5d09..0d926fbe 100644
--- a/testgen/commands/run_launch_db_config.py
+++ b/testgen/commands/run_launch_db_config.py
@@ -53,7 +53,6 @@ def _get_params_mapping() -> dict:
"TABLE_GROUPS_NAME": settings.DEFAULT_TABLE_GROUPS_NAME,
"TEST_SUITE": settings.DEFAULT_TEST_SUITE_KEY,
"TEST_SUITE_DESCRIPTION": settings.DEFAULT_TEST_SUITE_DESCRIPTION,
- "MONITOR_TEST_SUITE": settings.DEFAULT_MONITOR_TEST_SUITE_KEY,
"MAX_THREADS": settings.PROJECT_CONNECTION_MAX_THREADS,
"MAX_QUERY_CHARS": settings.PROJECT_CONNECTION_MAX_QUERY_CHAR,
"OBSERVABILITY_API_URL": settings.OBSERVABILITY_API_URL,
diff --git a/testgen/commands/run_observability_exporter.py b/testgen/commands/run_observability_exporter.py
index b8f966b9..71179e9d 100644
--- a/testgen/commands/run_observability_exporter.py
+++ b/testgen/commands/run_observability_exporter.py
@@ -15,6 +15,7 @@
execute_db_queries,
fetch_dict_from_db,
)
+from testgen.common.models import with_database_session
from testgen.common.models.test_suite import TestSuite
LOG = logging.getLogger("testgen")
@@ -268,11 +269,13 @@ def _get_input_parameters(input_parameters):
is_first = False
elif len(items) == item_number: # is last
value = item
- ret.append({"name": name.strip(), "value": value.strip()})
+ if value.strip():
+ ret.append({"name": name.strip(), "value": value.strip()})
else:
words = item.split(",")
value = ",".join(words[:-1]) # everything but the last word
- ret.append({"name": name.strip(), "value": value.strip()})
+ if value.strip():
+ ret.append({"name": name.strip(), "value": value.strip()})
name = words[-1] # the last word is the next name
return ret
@@ -309,6 +312,7 @@ def export_test_results(test_suite_id):
mark_exported_results(test_suite_id, updated_ids)
+@with_database_session
def run_observability_exporter(project_code, test_suite):
LOG.info("CurrentStep: Observability Export - Test Results")
test_suites = TestSuite.select_minimal_where(
diff --git a/testgen/commands/run_profiling.py b/testgen/commands/run_profiling.py
index 3764f584..c97ec695 100644
--- a/testgen/commands/run_profiling.py
+++ b/testgen/commands/run_profiling.py
@@ -9,10 +9,9 @@
from testgen.commands.queries.profiling_query import HygieneIssueType, ProfilingSQL, TableSampling
from testgen.commands.queries.refresh_data_chars_query import ColumnChars
from testgen.commands.queries.rollup_scores_query import RollupScoresSQL
-from testgen.commands.run_generate_tests import run_test_gen_queries
from testgen.commands.run_refresh_data_chars import run_data_chars_refresh
from testgen.commands.run_refresh_score_cards_results import run_refresh_score_cards_results
-from testgen.commands.run_test_execution import run_test_execution_in_background
+from testgen.commands.test_generation import run_monitor_generation, run_test_generation
from testgen.common import (
execute_db_queries,
fetch_dict_from_db,
@@ -114,11 +113,8 @@ def run_profiling(table_group_id: str | UUID, username: str | None = None, run_d
profiling_run.save()
send_profiling_run_notifications(profiling_run)
-
_rollup_profiling_scores(profiling_run, table_group)
-
- if bool(table_group.monitor_test_suite_id) and not table_group.last_complete_profile_run_id:
- _generate_monitor_tests(table_group_id, table_group.monitor_test_suite_id)
+ _generate_tests(table_group)
finally:
MixpanelService().send_event(
"run-profiling",
@@ -254,6 +250,7 @@ def update_frequency_progress(progress: ThreadedProgress) -> None:
LOG.info("Updating profiling results with frequency analysis and deleting staging")
execute_db_queries(sql_generator.update_frequency_analysis_results())
except Exception as e:
+ LOG.exception("Error running frequency analysis")
profiling_run.set_progress("freq_analysis", "Warning", error=f"Error encountered. {get_exception_message(e)}")
else:
if error_data:
@@ -294,6 +291,7 @@ def _run_hygiene_issue_detection(sql_generator: ProfilingSQL) -> None:
]
)
except Exception as e:
+ LOG.exception("Error detecting hygiene issues")
profiling_run.set_progress("hygiene_issues", "Warning", error=f"Error encountered. {get_exception_message(e)}")
else:
profiling_run.set_progress("hygiene_issues", "Completed")
@@ -315,14 +313,24 @@ def _rollup_profiling_scores(profiling_run: ProfilingRun, table_group: TableGrou
@with_database_session
-def _generate_monitor_tests(table_group_id: str, test_suite_id: str) -> None:
- try:
- monitor_test_suite = TestSuite.get(test_suite_id)
- if not monitor_test_suite:
- LOG.info("Skipping test generation on missing monitor test suite")
- else:
- LOG.info("Generating monitor tests")
- run_test_gen_queries(table_group_id, monitor_test_suite.test_suite, "Monitor")
- run_test_execution_in_background(test_suite_id)
- except Exception:
- LOG.exception("Error generating monitor tests")
+def _generate_tests(table_group: TableGroup) -> None:
+ is_first_profile_run = not table_group.last_complete_profile_run_id
+
+ if bool(table_group.monitor_test_suite_id):
+ monitor_suite = TestSuite.get(table_group.monitor_test_suite_id)
+ try:
+ run_monitor_generation(
+ table_group.monitor_test_suite_id,
+ # Only Freshness depends on profiling results
+ ["Freshness_Trend"],
+ # Insert for new tables only, if user disabled regeneration
+ mode="upsert" if is_first_profile_run or monitor_suite.monitor_regenerate_freshness else "insert",
+ )
+ except Exception:
+ LOG.exception("Error generating Freshness monitors")
+
+ if is_first_profile_run and bool(table_group.default_test_suite_id):
+ try:
+ run_test_generation(table_group.default_test_suite_id, "Standard")
+ except Exception:
+ LOG.exception(f"Error generating tests for test suite: {table_group.default_test_suite_id}")
diff --git a/testgen/commands/run_quick_start.py b/testgen/commands/run_quick_start.py
index 5c9ea325..f1885c69 100644
--- a/testgen/commands/run_quick_start.py
+++ b/testgen/commands/run_quick_start.py
@@ -1,24 +1,31 @@
import logging
+import math
+import random
+from datetime import datetime
+from typing import Any
import click
from testgen import settings
from testgen.commands.run_launch_db_config import get_app_db_params_mapping, run_launch_db_config
+from testgen.commands.test_generation import run_monitor_generation
from testgen.common.credentials import get_tg_schema
from testgen.common.database.database_service import (
+ apply_params,
create_database,
execute_db_queries,
- replace_params,
set_target_db_params,
)
from testgen.common.database.flavor.flavor_service import ConnectionParams
from testgen.common.models import with_database_session
from testgen.common.models.scores import ScoreDefinition
+from testgen.common.models.settings import PersistedSetting
from testgen.common.models.table_group import TableGroup
+from testgen.common.notifications.base import smtp_configured
from testgen.common.read_file import read_template_sql_file
LOG = logging.getLogger("testgen")
-
+random.seed(42)
def _get_max_date(iteration: int):
if iteration == 0:
@@ -85,7 +92,7 @@ def _prepare_connection_to_target_database(params_mapping):
set_target_db_params(connection_params)
-def _get_params_mapping(iteration: int = 0) -> dict:
+def _get_settings_params_mapping() -> dict:
return {
"TESTGEN_ADMIN_USER": settings.DATABASE_ADMIN_USER,
"TESTGEN_ADMIN_PASSWORD": settings.DATABASE_ADMIN_PASSWORD,
@@ -96,6 +103,12 @@ def _get_params_mapping(iteration: int = 0) -> dict:
"PROJECT_DB_HOST": settings.PROJECT_DATABASE_HOST,
"PROJECT_DB_PORT": settings.PROJECT_DATABASE_PORT,
"SQL_FLAVOR": settings.PROJECT_SQL_FLAVOR,
+ }
+
+
+def _get_quick_start_params_mapping(iteration: int = 0) -> dict:
+ return {
+ **_get_settings_params_mapping(),
"MAX_SUPPLIER_ID_SEQ": _get_max_supplierid_seq(iteration),
"MAX_PRODUCT_ID_SEQ": _get_max_productid_seq(iteration),
"MAX_CUSTOMER_ID_SEQ": _get_max_customerid_seq(iteration),
@@ -104,9 +117,62 @@ def _get_params_mapping(iteration: int = 0) -> dict:
}
+def _metric_cumulative_shift(iteration: int) -> tuple[float, float]:
+ """Compute cumulative metric shifts at a given iteration for Metric_Trend monitors.
+
+ Returns (discount_shift, price_shift) β the total shift from baseline
+ that should be applied to the underlying data at this iteration.
+ Uses composite sine waves for organic-looking oscillation patterns.
+ """
+ i = iteration
+ discount = -1.0 + 1.8 * math.sin(2 * math.pi * i / 14 + math.pi) + 0.7 * math.sin(2 * math.pi * i / 6 + math.pi + 0.5)
+ price = 80 * math.sin(2 * math.pi * i / 16) + 40 * math.sin(2 * math.pi * i / 7 + 0.3) + 100
+ return discount, price
+
+
+def _get_monitor_params_mapping(run_date: datetime, iteration: int = 0, weekday_morning_count: int = 0) -> dict:
+ # Volume: linear growth with jitter, spike at specific iteration for anomaly
+ if iteration == 60:
+ new_sales = 100
+ else:
+ new_sales = random.randint(5, 15) # noqa: S311
+
+ # Freshness: weekday morning updates with 1-day outage after schedule goes active
+ is_weekday = run_date.weekday() < 5
+ is_morning = run_date.hour < 12
+ is_outage = weekday_morning_count == 21
+ is_update_suppliers_iter = is_weekday and is_morning and not is_outage
+
+ # Metrics: compute deltas for discount and price shifts
+ curr_discount, curr_price = _metric_cumulative_shift(iteration)
+ prev_discount, prev_price = _metric_cumulative_shift(iteration - 1) if iteration > 1 else (0.0, 0.0)
+ discount_delta = round(curr_discount - prev_discount, 3)
+ price_delta = round(curr_price - prev_price, 2)
+
+ return {
+ **_get_settings_params_mapping(),
+ "ITERATION_NUMBER": iteration,
+ "RUN_DATE": run_date,
+ "NEW_SALES": new_sales,
+ "IS_ADD_CUSTOMER_COL_ITER": iteration == 47,
+ "IS_DELETE_CUSTOMER_COL_ITER": iteration == 58,
+ "IS_UPDATE_PRODUCT_ITER": not 24 < iteration < 28,
+ "IS_CREATE_RETURNS_TABLE_ITER": iteration == 52,
+ "IS_DELETE_CUSTOMER_ITER": iteration in (29, 36, 55),
+ "IS_UPDATE_SUPPLIERS_ITER": is_update_suppliers_iter,
+ "DISCOUNT_DELTA": discount_delta,
+ "PRICE_DELTA": price_delta,
+ }
+
+
+def _get_quick_start_query(template_file_name: str, params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
+ template = read_template_sql_file(template_file_name, "quick_start")
+ return apply_params(template, params), params
+
+
def run_quick_start(delete_target_db: bool) -> None:
# Init
- params_mapping = _get_params_mapping()
+ params_mapping = _get_quick_start_params_mapping()
_prepare_connection_to_target_database(params_mapping)
# Create DB
@@ -124,16 +190,18 @@ def run_quick_start(delete_target_db: bool) -> None:
app_db_params = get_app_db_params_mapping()
execute_db_queries(
[
- (replace_params(read_template_sql_file("initial_data_seeding.sql", "quick_start"), app_db_params), app_db_params),
+ _get_quick_start_query("initial_data_seeding.sql", app_db_params),
],
)
+ with_database_session(_setup_initial_config)()
+
# Schema and Populate target db
click.echo(f"Populating target db : {target_db_name}")
execute_db_queries(
[
- (replace_params(read_template_sql_file("recreate_target_data_schema.sql", "quick_start"), params_mapping), params_mapping),
- (replace_params(read_template_sql_file("populate_target_data.sql", "quick_start"), params_mapping), params_mapping),
+ _get_quick_start_query("recreate_target_data_schema.sql", params_mapping),
+ _get_quick_start_query("populate_target_data.sql", params_mapping),
],
use_target_db=True,
)
@@ -145,10 +213,15 @@ def run_quick_start(delete_target_db: bool) -> None:
)
)
with_database_session(score_definition.save)()
+ with_database_session(run_monitor_generation)("823a1fef-9b6d-48d5-9d0f-2db9812cc318", ["Volume_Trend", "Schema_Drift"])
+
+
+def _setup_initial_config():
+ PersistedSetting.set("SMTP_CONFIGURED", smtp_configured())
def run_quick_start_increment(iteration):
- params_mapping = _get_params_mapping(iteration)
+ params_mapping = _get_quick_start_params_mapping(iteration)
_prepare_connection_to_target_database(params_mapping)
target_db_name = params_mapping["PROJECT_DB"]
@@ -156,14 +229,29 @@ def run_quick_start_increment(iteration):
execute_db_queries(
[
- (replace_params(read_template_sql_file("update_target_data.sql", "quick_start"), params_mapping), params_mapping),
- (replace_params(read_template_sql_file(f"update_target_data_iter{iteration}.sql", "quick_start"), params_mapping), params_mapping),
+ _get_quick_start_query("update_target_data.sql", params_mapping),
+ _get_quick_start_query(f"update_target_data_iter{iteration}.sql", params_mapping),
],
use_target_db=True,
)
setup_cat_tests(iteration)
+def run_monitor_increment(run_date, iteration, weekday_morning_count=0):
+ params_mapping = _get_monitor_params_mapping(run_date, iteration, weekday_morning_count)
+ _prepare_connection_to_target_database(params_mapping)
+
+ target_db_name = params_mapping["PROJECT_DB"]
+ LOG.info(f"Incremental monitor updates of target db ({iteration}) : {target_db_name}")
+
+ execute_db_queries(
+ [
+ _get_quick_start_query("run_monitor_iteration.sql", params_mapping),
+ ],
+ use_target_db=True,
+ )
+
+
def setup_cat_tests(iteration):
if iteration == 0:
return
@@ -172,12 +260,11 @@ def setup_cat_tests(iteration):
elif iteration >=1:
sql_file = "update_cat_tests.sql"
- params_mapping = _get_params_mapping(iteration)
- query = replace_params(read_template_sql_file(sql_file, "quick_start"), params_mapping)
+ params_mapping = _get_quick_start_params_mapping(iteration)
execute_db_queries(
[
- (query, params_mapping),
+ _get_quick_start_query(sql_file, params_mapping),
],
use_target_db=False,
)
diff --git a/testgen/commands/run_test_execution.py b/testgen/commands/run_test_execution.py
index 14403f46..a809ad20 100644
--- a/testgen/commands/run_test_execution.py
+++ b/testgen/commands/run_test_execution.py
@@ -1,6 +1,7 @@
import logging
import subprocess
import threading
+from collections import defaultdict
from datetime import UTC, datetime, timedelta
from functools import partial
from typing import Literal
@@ -11,6 +12,8 @@
from testgen.commands.queries.execute_tests_query import TestExecutionDef, TestExecutionSQL
from testgen.commands.queries.rollup_scores_query import RollupScoresSQL
from testgen.commands.run_refresh_score_cards_results import run_refresh_score_cards_results
+from testgen.commands.test_generation import run_monitor_generation
+from testgen.commands.test_thresholds_prediction import TestThresholdsPrediction
from testgen.common import (
execute_db_queries,
fetch_dict_from_db,
@@ -25,6 +28,7 @@
from testgen.common.models.table_group import TableGroup
from testgen.common.models.test_run import TestRun
from testgen.common.models.test_suite import TestSuite
+from testgen.common.notifications.monitor_run import send_monitor_notifications
from testgen.common.notifications.test_run import send_test_run_notifications
from testgen.ui.session import session
from testgen.utils import get_exception_message
@@ -80,11 +84,14 @@ def run_test_execution(test_suite_id: str | UUID, username: str | None = None, r
data_chars = run_data_chars_refresh(connection, table_group, test_run.test_starttime)
test_run.set_progress("data_chars", "Completed")
- sql_generator = TestExecutionSQL(connection, table_group, test_run)
+ sql_generator = TestExecutionSQL(connection, table_group, test_suite, test_run)
+
+ if test_suite.is_monitor:
+ _sync_monitor_definitions(sql_generator)
# Update the thresholds before retrieving the test definitions in the next steps
- LOG.info("Updating historic test thresholds")
- execute_db_queries([sql_generator.update_historic_thresholds()])
+ LOG.info("Updating test thresholds based on history calculations")
+ execute_db_queries([sql_generator.update_history_calc_thresholds()])
LOG.info("Retrieving active test definitions in test suite")
test_defs = fetch_dict_from_db(*sql_generator.get_active_test_definitions())
@@ -116,7 +123,7 @@ def run_test_execution(test_suite_id: str | UUID, username: str | None = None, r
# Run metadata tests last so that results for other tests are available to them
for run_type in ["QUERY", "CAT", "METADATA"]:
if (run_test_defs := [td for td in valid_test_defs if td.run_type == run_type]):
- run_functions[run_type](run_test_defs)
+ run_functions[run_type](run_test_defs, save_progress=not test_suite.is_monitor)
else:
test_run.set_progress(run_type, "Completed")
LOG.info(f"No {run_type} tests to run")
@@ -150,17 +157,27 @@ def run_test_execution(test_suite_id: str | UUID, username: str | None = None, r
test_suite.last_complete_test_run_id = test_run.id
test_suite.save()
- send_test_run_notifications(test_run)
- _rollup_test_scores(test_run, table_group)
+ if not test_suite.is_monitor:
+ send_test_run_notifications(test_run)
+ _rollup_test_scores(test_run, table_group)
+ else:
+ send_monitor_notifications(test_run)
finally:
+ scoring_endtime = datetime.now(UTC) + time_delta
+ try:
+ TestThresholdsPrediction(test_suite, test_run.test_starttime).run()
+ except Exception:
+ LOG.exception("Error predicting test thresholds")
+
MixpanelService().send_event(
- "run-tests",
+ "run-monitors" if test_suite.is_monitor else "run-tests",
source=settings.ANALYTICS_JOB_SOURCE,
username=username,
sql_flavor=connection.sql_flavor_code,
test_count=test_run.test_ct,
run_duration=(test_run.test_endtime - test_run.test_starttime.replace(tzinfo=UTC)).total_seconds(),
- scoring_duration=(datetime.now(UTC) + time_delta - test_run.test_endtime).total_seconds(),
+ scoring_duration=(scoring_endtime - test_run.test_endtime).total_seconds(),
+ prediction_duration=(datetime.now(UTC) + time_delta - scoring_endtime).total_seconds(),
)
return f"""
@@ -169,7 +186,32 @@ def run_test_execution(test_suite_id: str | UUID, username: str | None = None, r
"""
-def _run_tests(sql_generator: TestExecutionSQL, run_type: Literal["QUERY", "METADATA"], test_defs: list[TestExecutionDef]) -> None:
+def _sync_monitor_definitions(sql_generator: TestExecutionSQL) -> None:
+ test_suite_id = sql_generator.test_run.test_suite_id
+
+ schema_changes = fetch_dict_from_db(*sql_generator.has_schema_changes())[0]
+ if schema_changes["has_table_drops"]:
+ run_monitor_generation(test_suite_id, ["Freshness_Trend", "Volume_Trend", "Metric_Trend"], mode="delete")
+ if schema_changes["has_table_adds"]:
+ # Freshness monitors will be inserted after profiling
+ run_monitor_generation(test_suite_id, ["Volume_Trend"], mode="insert")
+
+ # Regenerate monitors that errored in previous run
+ errored_monitors = fetch_dict_from_db(*sql_generator.get_errored_autogen_monitors())
+ if errored_monitors:
+ errored_by_type: dict[str, list[str]] = defaultdict(list)
+ for row in errored_monitors:
+ errored_by_type[row["test_type"]].append(row["table_name"])
+ for test_type, table_names in errored_by_type.items():
+ run_monitor_generation(test_suite_id, [test_type], mode="upsert", table_names=table_names)
+
+
+def _run_tests(
+ sql_generator: TestExecutionSQL,
+ run_type: Literal["QUERY", "METADATA"],
+ test_defs: list[TestExecutionDef],
+ save_progress: bool = False,
+) -> None:
test_run = sql_generator.test_run
test_run.set_progress(run_type, "Running")
test_run.save()
@@ -190,7 +232,7 @@ def update_test_progress(progress: ThreadedProgress) -> None:
[sql_generator.run_query_test(td) for td in test_defs],
use_target_db=run_type != "METADATA",
max_threads=sql_generator.connection.max_threads,
- progress_callback=update_test_progress,
+ progress_callback=update_test_progress if save_progress else None,
)
if test_results:
@@ -215,7 +257,11 @@ def update_test_progress(progress: ThreadedProgress) -> None:
)
-def _run_cat_tests(sql_generator: TestExecutionSQL, test_defs: list[TestExecutionDef]) -> None:
+def _run_cat_tests(
+ sql_generator: TestExecutionSQL,
+ test_defs: list[TestExecutionDef],
+ save_progress: bool = False,
+) -> None:
test_run = sql_generator.test_run
test_run.set_progress("CAT", "Running")
test_run.save()
@@ -241,7 +287,7 @@ def update_aggegate_progress(progress: ThreadedProgress) -> None:
aggregate_queries,
use_target_db=True,
max_threads=sql_generator.connection.max_threads,
- progress_callback=update_aggegate_progress,
+ progress_callback=update_aggegate_progress if save_progress else None,
)
if aggregate_results:
@@ -281,7 +327,7 @@ def update_single_progress(progress: ThreadedProgress) -> None:
single_queries,
use_target_db=True,
max_threads=sql_generator.connection.max_threads,
- progress_callback=update_single_progress,
+        progress_callback=update_single_progress if save_progress else None,
)
if single_results:
diff --git a/testgen/commands/test_generation.py b/testgen/commands/test_generation.py
new file mode 100644
index 00000000..0583112c
--- /dev/null
+++ b/testgen/commands/test_generation.py
@@ -0,0 +1,211 @@
+import dataclasses
+import logging
+from datetime import UTC, datetime, timedelta
+from typing import Literal
+from uuid import UUID
+
+from testgen import settings
+from testgen.common.database.database_service import (
+ execute_db_queries,
+ fetch_dict_from_db,
+ get_flavor_service,
+ replace_params,
+)
+from testgen.common.mixpanel_service import MixpanelService
+from testgen.common.models.connection import Connection
+from testgen.common.models.table_group import TableGroup
+from testgen.common.models.test_suite import TestSuite
+from testgen.common.read_file import read_template_sql_file
+from testgen.utils import to_sql_timestamp
+
+LOG = logging.getLogger("testgen")
+
+GenerationSet = Literal["Standard", "Monitor"]
+MonitorTestType = Literal["Freshness_Trend", "Volume_Trend", "Schema_Drift", "Metric_Trend"]
+MonitorGenerationMode = Literal["upsert", "insert", "delete"]
+
+@dataclasses.dataclass
+class TestTypeParams:
+ test_type: str
+ selection_criteria: str | None
+ generation_template: str | None
+ default_parm_columns: str | None
+ default_parm_values: str | None
+
+
+# Generate tests for a regular non-monitor test suite - don't use for monitors
+def run_test_generation(
+ test_suite_id: str | UUID,
+ generation_set: GenerationSet = "Standard",
+ test_types: list[str] | None = None,
+) -> str:
+ if test_suite_id is None:
+ raise ValueError("Test Suite ID was not specified")
+
+ LOG.info(f"Starting test generation for test suite {test_suite_id}")
+
+ LOG.info("Retrieving connection, table group, and test suite parameters")
+ test_suite = TestSuite.get(test_suite_id)
+ if test_suite.is_monitor:
+ raise ValueError("Cannot run regular test generation for monitor suite")
+ table_group = TableGroup.get(test_suite.table_groups_id)
+ connection = Connection.get(table_group.connection_id)
+
+ success = False
+ try:
+ TestGeneration(connection, table_group, test_suite, generation_set, test_types).run()
+ success = True
+ except Exception:
+ LOG.exception("Test generation encountered an error.")
+ finally:
+ MixpanelService().send_event(
+ "generate-tests",
+ source=settings.ANALYTICS_JOB_SOURCE,
+ sql_flavor=connection.sql_flavor,
+ generation_set=generation_set,
+ )
+
+ return "Test generation completed." if success else "Test generation encountered an error. Check log for details."
+
+
+def run_monitor_generation(
+ monitor_suite_id: str | UUID,
+ monitors: list[MonitorTestType],
+ mode: MonitorGenerationMode = "upsert",
+ table_names: list[str] | None = None,
+) -> None:
+ """
+ Modes:
+ - "upsert": Add tests for new tables + update tests for existing tables + no deletion
+ - "insert": Only add tests for new tables
+ - "delete": Only delete tests for dropped tables
+ """
+ if monitor_suite_id is None:
+ raise ValueError("Monitor Suite ID was not specified")
+
+ LOG.info(f"Starting monitor generation for {monitor_suite_id} (Mode = {mode}, Monitors = {monitors})")
+
+ monitor_suite = TestSuite.get(monitor_suite_id)
+ if not monitor_suite.is_monitor:
+ raise ValueError("Cannot run monitor generation for regular test suite")
+ table_group = TableGroup.get(monitor_suite.table_groups_id)
+ connection = Connection.get(table_group.connection_id)
+
+ TestGeneration(connection, table_group, monitor_suite, "Monitor", monitors).monitor_run(mode, table_names=table_names)
+
+
+class TestGeneration:
+
+ def __init__(
+ self,
+ connection: Connection,
+ table_group: TableGroup,
+ test_suite: TestSuite,
+ generation_set: str,
+ test_types_filter: list[MonitorTestType] | None = None,
+ ):
+ self.connection = connection
+ self.table_group = table_group
+ self.test_suite = test_suite
+ self.generation_set = generation_set
+ self.test_types_filter = test_types_filter
+ self.flavor = connection.sql_flavor
+ self.flavor_service = get_flavor_service(self.flavor)
+
+ self.run_date = datetime.now(UTC)
+ self.as_of_date = self.run_date
+ if (delay_days := int(self.table_group.profiling_delay_days)):
+ self.as_of_date = self.run_date - timedelta(days=delay_days)
+
+ def run(self) -> None:
+ LOG.info("Running test generation queries")
+ execute_db_queries([
+ *self._get_generation_queries(),
+ self._get_query("delete_stale_autogen_tests.sql"),
+ ])
+
+ def monitor_run(self, mode: MonitorGenerationMode, table_names: list[str] | None = None) -> None:
+ if mode == "delete":
+ execute_db_queries([self._get_query("delete_stale_monitors.sql")])
+ return
+
+ extra_params = {"INSERT_ONLY": mode == "insert"}
+ if table_names:
+ table_list = ", ".join(f"'{table}'" for table in table_names)
+ extra_params["TABLE_FILTER"] = f"AND table_name IN ({table_list})"
+
+ LOG.info("Running monitor generation queries")
+ execute_db_queries(
+ self._get_generation_queries(extra_params=extra_params),
+ )
+
+ def _get_generation_queries(self, extra_params: dict | None = None) -> list[tuple[str, dict]]:
+ test_types = fetch_dict_from_db(*self._get_query("get_test_types.sql", extra_params=extra_params))
+ test_types = [TestTypeParams(**item) for item in test_types]
+
+ if self.test_types_filter:
+ test_types = [tt for tt in test_types if tt.test_type in self.test_types_filter]
+
+ selection_queries = [
+ self._get_query("gen_selection_tests.sql", test_type=tt, extra_params=extra_params)
+ for tt in test_types
+ if tt.selection_criteria and tt.selection_criteria != "TEMPLATE"
+ ]
+
+ template_queries = []
+ for tt in test_types:
+ if template_file := tt.generation_template:
+ # Try flavor-specific template first, then fall back to generic
+ for directory in [f"flavors/{self.flavor}/gen_query_tests", "gen_query_tests", "gen_funny_cat_tests"]:
+ try:
+ template_queries.append(self._get_query(template_file, directory, extra_params=extra_params))
+ break
+ except (ValueError, ModuleNotFoundError):
+ continue
+ else:
+ LOG.warning(f"Template file '{template_file}' not found for test type '{tt.test_type}'")
+
+ return [*selection_queries, *template_queries]
+
+ def _get_params(self, test_type: TestTypeParams | None = None) -> dict:
+ params = {}
+ if test_type:
+ params.update({
+ "TEST_TYPE": test_type.test_type,
+ # Replace these first since they may contain other params
+ "SELECTION_CRITERIA": test_type.selection_criteria,
+ "DEFAULT_PARM_COLUMNS": test_type.default_parm_columns,
+ "DEFAULT_PARM_COLUMNS_UPDATE": ",".join([
+ f"{column} = EXCLUDED.{column.strip()}"
+ for column in test_type.default_parm_columns.split(",")
+ ]) if test_type.default_parm_columns else "",
+ "DEFAULT_PARM_VALUES": test_type.default_parm_values,
+ })
+ params.update({
+ "TABLE_GROUPS_ID": self.table_group.id,
+ "TEST_SUITE_ID": self.test_suite.id,
+ "DATA_SCHEMA": self.table_group.table_group_schema,
+ "GENERATION_SET": self.generation_set,
+ "TEST_TYPES_FILTER": self.test_types_filter,
+ "RUN_DATE": to_sql_timestamp(self.run_date),
+ "AS_OF_DATE": to_sql_timestamp(self.as_of_date),
+ "SQL_FLAVOR": self.flavor,
+ "QUOTE": self.flavor_service.quote_character,
+ "INSERT_ONLY": False,
+ "TABLE_FILTER": "",
+ })
+ return params
+
+ def _get_query(
+ self,
+ template_file_name: str,
+ sub_directory: str | None = "generation",
+ test_type: TestTypeParams | None = None,
+ extra_params: dict | None = None,
+ ) -> tuple[str, dict | None]:
+ query = read_template_sql_file(template_file_name, sub_directory)
+ params = self._get_params(test_type)
+ if extra_params:
+ params.update(extra_params)
+ query = replace_params(query, params)
+ return query, params
diff --git a/testgen/commands/test_thresholds_prediction.py b/testgen/commands/test_thresholds_prediction.py
new file mode 100644
index 00000000..ca7b679b
--- /dev/null
+++ b/testgen/commands/test_thresholds_prediction.py
@@ -0,0 +1,302 @@
+import json
+import logging
+from datetime import datetime
+
+import pandas as pd
+from scipy import stats
+
+from testgen.common.database.database_service import (
+ execute_db_queries,
+ fetch_dict_from_db,
+ replace_params,
+ write_to_app_db,
+)
+from testgen.common.freshness_service import (
+ get_freshness_gap_threshold,
+ infer_schedule,
+ minutes_to_next_deadline,
+ resolve_holiday_dates,
+)
+from testgen.common.models import with_database_session
+from testgen.common.models.scheduler import JobSchedule
+from testgen.common.models.test_suite import PredictSensitivity, TestSuite
+from testgen.common.read_file import read_template_sql_file
+from testgen.common.time_series_service import (
+ NotEnoughData,
+ get_sarimax_forecast,
+)
+from testgen.utils import to_dataframe, to_sql_timestamp
+
+LOG = logging.getLogger("testgen")
+
+NUM_FORECAST = 10
+T_DISTRIBUTION_THRESHOLD = 20
+
+Z_SCORE_MAP = {
+ ("lower_tolerance", PredictSensitivity.low): -3.0, # 0.13th percentile
+ ("lower_tolerance", PredictSensitivity.medium): -2.5, # 0.62nd percentile
+ ("lower_tolerance", PredictSensitivity.high): -2.0, # 2.3rd percentile
+ ("upper_tolerance", PredictSensitivity.high): 2.0, # 97.7th percentile
+ ("upper_tolerance", PredictSensitivity.medium): 2.5, # 99.4th percentile
+ ("upper_tolerance", PredictSensitivity.low): 3.0, # 99.87th percentile
+}
+
+FRESHNESS_THRESHOLD_MAP = {
+ # upper_pct floor_mult lower_pct
+ PredictSensitivity.high: (80, 1.0, 20),
+ PredictSensitivity.medium: (95, 1.25, 10),
+ PredictSensitivity.low: (99, 1.5, 5),
+}
+
+SCHEDULE_DEADLINE_BUFFER_HOURS = {
+ PredictSensitivity.high: 1.5,
+ PredictSensitivity.medium: 3.0,
+ PredictSensitivity.low: 5.0,
+}
+
+STALENESS_FACTOR_MAP = {
+ PredictSensitivity.high: 0.75,
+ PredictSensitivity.medium: 0.85,
+ PredictSensitivity.low: 0.95,
+}
+
+
+class TestThresholdsPrediction:
+ staging_table = "stg_test_definition_updates"
+ staging_columns = (
+ "test_suite_id",
+ "test_definition_id",
+ "run_date",
+ "lower_tolerance",
+ "upper_tolerance",
+ "threshold_value",
+ "prediction",
+ )
+
+ @with_database_session
+ def __init__(self, test_suite: TestSuite, run_date: datetime):
+ self.test_suite = test_suite
+ self.run_date = run_date
+ schedule = JobSchedule.get(JobSchedule.kwargs["test_suite_id"].astext == str(test_suite.id))
+        self.tz = (schedule.cron_tz or "UTC") if schedule else None
+
+ def run(self) -> None:
+ LOG.info("Retrieving historical test results for training prediction models")
+ test_results = fetch_dict_from_db(*self._get_query("get_historical_test_results.sql"))
+ if test_results:
+ df = to_dataframe(test_results, coerce_float=True)
+ grouped_dfs = df.groupby("test_definition_id", group_keys=False)
+
+ LOG.info(f"Training prediction models for tests: {len(grouped_dfs)}")
+ prediction_results = []
+ for test_def_id, group in grouped_dfs:
+ test_type = group["test_type"].iloc[0]
+ history = group[["test_time", "result_signal"]]
+ history = history.set_index("test_time")
+
+ test_prediction = [
+ self.test_suite.id,
+ test_def_id,
+ to_sql_timestamp(self.run_date),
+ ]
+ if test_type == "Freshness_Trend":
+ lower, upper, staleness, prediction = compute_freshness_threshold(
+ history,
+ sensitivity=self.test_suite.predict_sensitivity or PredictSensitivity.medium,
+ min_lookback=self.test_suite.predict_min_lookback or 1,
+ exclude_weekends=self.test_suite.predict_exclude_weekends,
+ holiday_codes=self.test_suite.holiday_codes_list,
+ schedule_tz=self.tz,
+ )
+ test_prediction.extend([lower, upper, staleness, prediction])
+ else:
+ lower, upper, prediction = compute_sarimax_threshold(
+ history,
+ sensitivity=self.test_suite.predict_sensitivity or PredictSensitivity.medium,
+ min_lookback=self.test_suite.predict_min_lookback or 1,
+ exclude_weekends=self.test_suite.predict_exclude_weekends,
+ holiday_codes=self.test_suite.holiday_codes_list,
+ schedule_tz=self.tz,
+ )
+ if test_type == "Volume_Trend":
+ if lower is not None:
+ lower = max(lower, 0.0)
+ if upper is not None:
+ upper = max(upper, 0.0)
+ test_prediction.extend([lower, upper, None, prediction])
+
+ prediction_results.append(test_prediction)
+
+ LOG.info("Writing predicted test thresholds to staging")
+ write_to_app_db(prediction_results, self.staging_columns, self.staging_table)
+
+ LOG.info("Updating predicted test thresholds and deleting staging")
+ execute_db_queries([
+ self._get_query("update_predicted_test_thresholds.sql"),
+ self._get_query("delete_staging_test_definitions.sql"),
+ ])
+
+ def _get_query(
+ self,
+ template_file_name: str,
+ sub_directory: str | None = "prediction",
+ ) -> tuple[str, dict]:
+ params = {
+ "TEST_SUITE_ID": self.test_suite.id,
+ "RUN_DATE": to_sql_timestamp(self.run_date),
+ }
+ query = read_template_sql_file(template_file_name, sub_directory)
+ query = replace_params(query, params)
+ return query, params
+
+
+def compute_freshness_threshold(
+ history: pd.DataFrame,
+ sensitivity: PredictSensitivity,
+ min_lookback: int = 1,
+ exclude_weekends: bool = False,
+ holiday_codes: list[str] | None = None,
+ schedule_tz: str | None = None,
+) -> tuple[float | None, float | None, float | None, str | None]:
+ """Compute freshness gap thresholds in business minutes.
+
+ Returns (lower, upper, staleness_threshold, prediction_json) in business minutes,
+ or (None, None, None, None) if not enough data.
+ """
+ if len(history) < min_lookback:
+ return None, None, None, None
+
+ upper_percentile, floor_multiplier, lower_percentile = FRESHNESS_THRESHOLD_MAP[sensitivity]
+ staleness_factor = STALENESS_FACTOR_MAP[sensitivity]
+
+ try:
+ result = get_freshness_gap_threshold(
+ history,
+ upper_percentile=upper_percentile,
+ floor_multiplier=floor_multiplier,
+ lower_percentile=lower_percentile,
+ exclude_weekends=exclude_weekends,
+ holiday_codes=holiday_codes,
+ tz=schedule_tz,
+ staleness_factor=staleness_factor,
+ )
+ except NotEnoughData:
+ return None, None, None, None
+
+ lower, upper = result.lower, result.upper
+ staleness: float | None = None
+ prediction_data: dict = {}
+
+ if not schedule_tz:
+ return lower, upper, staleness, json.dumps(prediction_data)
+
+ # --- Schedule inference ---
+ deadline_buffer = SCHEDULE_DEADLINE_BUFFER_HOURS[sensitivity]
+
+ schedule = infer_schedule(history, schedule_tz)
+ if not schedule:
+ return lower, upper, staleness, json.dumps(prediction_data)
+
+ prediction_data.update({
+ "schedule_stage": schedule.stage,
+ "frequency": schedule.frequency,
+ "active_days": sorted(schedule.active_days) if schedule.active_days else None,
+ "window_start": schedule.window_start,
+ "window_end": schedule.window_end,
+ # Metadata stored for debugging purposes
+ "confidence": round(schedule.confidence, 4),
+ "num_events": schedule.num_events,
+ "sensitivity": sensitivity.value,
+ "deadline_buffer_hours": deadline_buffer,
+ })
+
+ if schedule.stage == "active":
+ excluded_days = frozenset(range(7)) - schedule.active_days if schedule.active_days else None
+
+ # For sub-daily schedules, apply window exclusion for overnight gaps
+ has_window = (
+ schedule.frequency == "sub_daily"
+ and schedule.window_start is not None
+ and schedule.window_end is not None
+ )
+
+ # Recompute gap thresholds with schedule-aware exclusion
+ if excluded_days or has_window:
+ try:
+ result = get_freshness_gap_threshold(
+ history,
+ upper_percentile=upper_percentile,
+ floor_multiplier=floor_multiplier,
+ lower_percentile=lower_percentile,
+ exclude_weekends=exclude_weekends,
+ holiday_codes=holiday_codes,
+ tz=schedule_tz,
+ staleness_factor=staleness_factor,
+ excluded_days=excluded_days,
+ window_start=schedule.window_start if has_window else None,
+ window_end=schedule.window_end if has_window else None,
+ )
+ lower, upper = result.lower, result.upper
+ staleness = result.staleness
+ except NotEnoughData:
+ pass # Keep first-pass thresholds
+
+ # Override upper threshold with schedule-based deadline (daily/weekly only)
+ if schedule.frequency != "sub_daily":
+ holiday_dates = resolve_holiday_dates(holiday_codes, history.index) if holiday_codes else None
+ schedule_upper = minutes_to_next_deadline(
+ result.last_update, schedule,
+ exclude_weekends, holiday_dates, schedule_tz,
+ deadline_buffer, excluded_days=excluded_days,
+ )
+ if schedule_upper is not None:
+ upper = schedule_upper
+
+ return lower, upper, staleness, json.dumps(prediction_data)
+
+
+def compute_sarimax_threshold(
+ history: pd.DataFrame,
+ sensitivity: PredictSensitivity,
+ num_forecast: int = NUM_FORECAST,
+ min_lookback: int = 1,
+ exclude_weekends: bool = False,
+ holiday_codes: list[str] | None = None,
+ schedule_tz: str | None = None,
+) -> tuple[float | None, float | None, str | None]:
+ """Compute SARIMAX-based thresholds for the next forecast point.
+
+ Returns (lower, upper, forecast_json) or (None, None, None) if insufficient data.
+ """
+ if len(history) < min_lookback:
+ return None, None, None
+
+ try:
+ forecast = get_sarimax_forecast(
+ history,
+ num_forecast=num_forecast,
+ exclude_weekends=exclude_weekends,
+ holiday_codes=holiday_codes,
+ tz=schedule_tz,
+ )
+
+ num_points = len(history)
+ for key, z_score in Z_SCORE_MAP.items():
+ if num_points < T_DISTRIBUTION_THRESHOLD:
+ percentile = stats.norm.cdf(z_score)
+ multiplier = stats.t.ppf(percentile, df=num_points - 1)
+ else:
+ multiplier = z_score
+ column = f"{key[0]}|{key[1].value}"
+ forecast[column] = forecast["mean"] + (multiplier * forecast["se"])
+
+ next_date = forecast.index[0]
+ lower_tolerance = forecast.at[next_date, f"lower_tolerance|{sensitivity.value}"]
+ upper_tolerance = forecast.at[next_date, f"upper_tolerance|{sensitivity.value}"]
+
+ if pd.isna(lower_tolerance) or pd.isna(upper_tolerance):
+ return None, None, None
+ else:
+ return float(lower_tolerance), float(upper_tolerance), forecast.to_json()
+ except NotEnoughData:
+ return None, None, None
diff --git a/testgen/common/__init__.py b/testgen/common/__init__.py
index e413faa0..900b4c09 100644
--- a/testgen/common/__init__.py
+++ b/testgen/common/__init__.py
@@ -3,6 +3,5 @@
from .clean_sql import *
from .credentials import *
from .encrypt import *
-from .get_pipeline_parms import *
from .logs import *
from .read_file import *
diff --git a/testgen/common/database/database_service.py b/testgen/common/database/database_service.py
index cc712403..4b340b18 100644
--- a/testgen/common/database/database_service.py
+++ b/testgen/common/database/database_service.py
@@ -2,6 +2,7 @@
import csv
import importlib
import logging
+import re
from collections.abc import Callable, Iterable
from contextlib import suppress
from dataclasses import dataclass, field
@@ -193,7 +194,7 @@ def fetch_data(query: str, params: dict | None, index: int) -> tuple[list[Legacy
LOG.exception(f"Failed to execute threaded query: {query}")
return row_data, column_names, index, error
-
+
result_data: list[LegacyRow] = []
result_columns: list[str] = []
error_data: dict[int, str] = {}
@@ -284,12 +285,42 @@ def write_to_app_db(data: list[LegacyRow], column_names: Iterable[str], table_na
connection.close()
+def apply_params(query: str, params: dict[str, Any]) -> str:
+ query = process_conditionals(query, params)
+ query = replace_params(query, params)
+ return query
+
+
def replace_params(query: str, params: dict[str, Any]) -> str:
for key, value in params.items():
query = query.replace(f"{{{key}}}", "" if value is None else str(value))
return query
+def process_conditionals(query: str, params: dict[str, Any]) -> str:
+ re_pattern = re.compile(r"^--\s+TG-(IF|ELSE|ENDIF)(?:\s+(\w+))?\s*$")
+ condition = None
+ updated_query = []
+ for line in query.splitlines(True):
+ if re_match := re_pattern.match(line):
+ match re_match.group(1):
+ case "IF" if condition is None and (variable := re_match.group(2)) is not None:
+ condition = bool(params.get(variable))
+ case "ELSE" if condition is not None:
+ condition = not condition
+ case "ENDIF" if condition is not None:
+ condition = None
+ case _:
+ raise ValueError("Template conditional misused")
+ elif condition is not False:
+ updated_query.append(line)
+
+ if condition is not None:
+ raise ValueError("Template conditional misused")
+
+ return "".join(updated_query)
+
+
def get_queries_for_command(
sub_directory: str, params: dict[str, Any], mask: str = r"^.*sql$", path: str | None = None
) -> list[str]:
diff --git a/testgen/common/freshness_service.py b/testgen/common/freshness_service.py
new file mode 100644
index 00000000..f7810787
--- /dev/null
+++ b/testgen/common/freshness_service.py
@@ -0,0 +1,609 @@
+import json
+import logging
+import zoneinfo
+from collections import Counter
+from datetime import date, datetime
+from typing import NamedTuple
+
+import numpy as np
+import pandas as pd
+
+from testgen.common.time_series_service import NotEnoughData, get_holiday_dates
+
+LOG = logging.getLogger("testgen")
+
+# Minimum completed gaps needed before freshness threshold is meaningful
+MIN_FRESHNESS_GAPS = 5
+
+# Default sliding window size — use only the most recent N gaps
+MAX_FRESHNESS_GAPS = 40
+
+
+class FreshnessThreshold(NamedTuple):
+ lower: float | None
+ upper: float
+ staleness: float
+ last_update: pd.Timestamp
+
+
+class InferredSchedule(NamedTuple):
+ stage: str # "training", "tentative", "active", "irregular"
+ frequency: str # "sub_daily", "daily", "weekly", "irregular"
+ active_days: frozenset[int] # weekday numbers (0=Mon, 6=Sun)
+ window_start: float | None # hour of day (0-24), P10
+ window_end: float | None # hour of day (0-24), P90
+ confidence: float # fraction of events matching schedule
+ num_events: int # total update events used
+
+
+def get_freshness_gap_threshold(
+ history: pd.DataFrame,
+ upper_percentile: float,
+ floor_multiplier: float,
+ lower_percentile: float,
+ exclude_weekends: bool = False,
+ holiday_codes: list[str] | None = None,
+ tz: str | None = None,
+ staleness_factor: float = 0.85,
+ excluded_days: frozenset[int] | None = None,
+ window_start: float | None = None,
+ window_end: float | None = None,
+) -> FreshnessThreshold:
+ """Compute freshness thresholds from completed gap durations.
+
+ Extracts gaps between consecutive table updates (points where result_signal == 0)
+ and returns upper and lower thresholds based on percentiles, with a floor for the
+ upper bound derived from the maximum observed gap.
+
+ When exclusion flags are set, gap durations are normalized by subtracting
+ excluded time (weekends/holidays) that fall within each gap.
+
+ A sliding window limits the number of recent gaps used, so old outliers
+ age out of the distribution over time.
+
+ :param history: DataFrame with DatetimeIndex and a result_signal column.
+ :param upper_percentile: Percentile for upper bound (e.g. 80, 95, 99).
+ :param floor_multiplier: Multiplied by max gap to set an upper floor (e.g. 1.0, 1.25, 1.5).
+ :param lower_percentile: Percentile for lower bound (e.g. 5, 10, 20).
+ :param exclude_weekends: Subtract weekend days from gap durations.
+ :param holiday_codes: Country/market codes for holidays to subtract from gap durations.
+ :param tz: IANA timezone (e.g. "America/New_York") for weekday/holiday determination.
+ :returns: FreshnessThreshold with lower (in business minutes, None if not computed),
+ upper (in business minutes), and last_update timestamp.
+ :raises NotEnoughData: If fewer than MIN_FRESHNESS_GAPS completed gaps are found.
+ """
+ signal = history.iloc[:, 0]
+ update_times = signal.index[signal == 0]
+
+ if len(update_times) - 1 < MIN_FRESHNESS_GAPS:
+ raise NotEnoughData(
+ f"Need at least {MIN_FRESHNESS_GAPS} completed gaps, found {max(len(update_times) - 1, 0)}."
+ )
+
+ has_exclusions = exclude_weekends or holiday_codes or excluded_days or (window_start is not None and window_end is not None)
+ holiday_dates = resolve_holiday_dates(holiday_codes, history.index) if holiday_codes else None
+ gaps_minutes = np.diff(update_times).astype("timedelta64[m]").astype(float)
+
+ if has_exclusions:
+ for i in range(len(gaps_minutes)):
+ excluded_minutes = count_excluded_minutes(
+ update_times[i], update_times[i + 1], exclude_weekends, holiday_dates,
+ tz=tz, excluded_days=excluded_days,
+ window_start=window_start, window_end=window_end,
+ )
+ gaps_minutes[i] = max(gaps_minutes[i] - excluded_minutes, 0)
+
+ # Sliding window: keep only the most recent gaps
+ if len(gaps_minutes) > MAX_FRESHNESS_GAPS:
+ gaps_minutes = gaps_minutes[-MAX_FRESHNESS_GAPS:]
+
+ upper = max(
+ float(np.percentile(gaps_minutes, upper_percentile)),
+ float(np.max(gaps_minutes)) * floor_multiplier,
+ )
+
+ lower = float(np.percentile(gaps_minutes, lower_percentile))
+ if lower <= 0:
+ lower = None
+
+ staleness = float(np.median(gaps_minutes)) * staleness_factor
+
+ return FreshnessThreshold(lower=lower, upper=upper, staleness=staleness, last_update=update_times[-1])
+
+
+def resolve_holiday_dates(codes: list[str], index: pd.DatetimeIndex) -> set[date]:
+ return {d.date() if isinstance(d, datetime) else d for d in get_holiday_dates(codes, index)}
+
+
+class ScheduleParams(NamedTuple):
+ excluded_days: frozenset[int] | None
+ window_start: float | None
+ window_end: float | None
+
+
+def get_schedule_params(prediction: dict | str | None) -> ScheduleParams:
+ empty = ScheduleParams(excluded_days=None, window_start=None, window_end=None)
+ if not prediction:
+ return empty
+ prediction = prediction if isinstance(prediction, dict) else json.loads(prediction)
+
+ if prediction.get("schedule_stage") != "active":
+ return empty
+
+ active_days = prediction.get("active_days")
+ excluded_days = frozenset(range(7)) - frozenset(active_days) if active_days else None
+
+ window_start: float | None = None
+ window_end: float | None = None
+ if prediction.get("frequency") == "sub_daily":
+ if (ws := prediction.get("window_start")) is not None and (we := prediction.get("window_end")) is not None:
+ window_start = float(ws)
+ window_end = float(we)
+
+ return ScheduleParams(excluded_days=excluded_days, window_start=window_start, window_end=window_end)
+
+
+def is_excluded_day(
+ dt: pd.Timestamp,
+ exclude_weekends: bool,
+ holiday_dates: set[date] | None,
+ tz: str | None = None,
+ excluded_days: frozenset[int] | None = None,
+ window_start: float | None = None,
+ window_end: float | None = None,
+) -> bool:
+ """Check if a timestamp falls on excluded time.
+
+ Excluded time includes:
+ - Weekends (if exclude_weekends is True)
+ - Holidays (if holiday_dates is provided)
+ - Inferred inactive days (if excluded_days is provided)
+ - Hours outside the active window on active days (if window_start/window_end are provided)
+
+ When tz is provided, naive timestamps are interpreted as UTC and converted
+ to the given timezone for the weekday/holiday/hour check.
+ """
+ if tz:
+ local_ts = _to_local(dt, tz)
+ date_ = local_ts.date()
+ else:
+ local_ts = dt
+ date_ = dt.date()
+
+ if exclude_weekends and date_.weekday() >= 5:
+ return True
+ if excluded_days and date_.weekday() in excluded_days:
+ return True
+ if holiday_dates and date_ in holiday_dates:
+ return True
+
+ if window_start is not None and window_end is not None:
+ hour = local_ts.hour + local_ts.minute / 60.0
+ if not _is_in_time_window(hour, window_start, window_end):
+ return True
+ return False
+
+
+def next_business_day_start(
+ dt: pd.Timestamp,
+ exclude_weekends: bool,
+ holiday_dates: set[date] | None,
+ tz: str | None = None,
+ excluded_days: frozenset[int] | None = None,
+) -> pd.Timestamp:
+ day = (pd.Timestamp(dt) + pd.DateOffset(days=1)).normalize()
+ while is_excluded_day(day, exclude_weekends, holiday_dates, tz=tz, excluded_days=excluded_days):
+ day += pd.DateOffset(days=1)
+ return day
+
+
+def count_excluded_minutes(
+ start: pd.Timestamp,
+ end: pd.Timestamp,
+ exclude_weekends: bool,
+ holiday_dates: set[date] | None,
+ tz: str | None = None,
+ excluded_days: frozenset[int] | None = None,
+ window_start: float | None = None,
+ window_end: float | None = None,
+) -> float:
+ """Count excluded minutes between two timestamps, including partial days.
+
+ Iterates day-by-day from start to end, counting the overlap between each
+ excluded day (weekend or holiday) and the [start, end] interval. Partial
+ excluded days at the boundaries are correctly prorated.
+
+ When window_start/window_end are provided (sub-daily active schedules),
+ hours outside the [window_start, window_end] range on active days are also
+ counted as excluded. Fully excluded days (weekends, holidays, inactive days)
+    still count their entire overlap as excluded — the window only applies to
+ days that are otherwise active.
+
+ When tz is provided, naive timestamps are converted to the local timezone
+ for weekday/holiday determination. The overlap is computed in naive local
+ time (timezone stripped after conversion) so that every calendar day is
+    exactly 24 h — this keeps excluded minutes consistent with UTC-based raw
+ gaps and avoids DST distortion (fall-back days counting 25 h, spring-forward
+ days counting 23 h).
+ """
+ start = pd.Timestamp(start)
+ end = pd.Timestamp(end)
+
+ if tz:
+        # Convert to local, then strip timezone — naive local time so each day
+ # is exactly 24 h. This prevents DST transitions from inflating/deflating
+ # excluded time relative to the UTC-based raw gap that callers subtract from.
+ start = _to_local(start, tz).tz_localize(None)
+ end = _to_local(end, tz).tz_localize(None)
+
+ if start >= end:
+ return 0.0
+
+ has_window = window_start is not None and window_end is not None
+
+ total_minutes = 0.0
+ day_start = start.normalize()
+
+ while day_start < end:
+ next_day = day_start + pd.Timedelta(days=1)
+
+ if is_excluded_day(day_start, exclude_weekends, holiday_dates, excluded_days=excluded_days):
+ # Full day excluded (weekend, holiday, inactive day)
+ overlap_start = max(start, day_start)
+ overlap_end = min(end, next_day)
+ total_minutes += (overlap_end - overlap_start).total_seconds() / 60
+ elif has_window:
+ # Active day but with window exclusion: exclude hours outside the window
+ # Compute the active window boundaries for this calendar day
+ win_open = day_start + pd.Timedelta(hours=window_start)
+ win_close = day_start + pd.Timedelta(hours=window_end)
+
+ # Clip to the [start, end] interval
+ overlap_start = max(start, day_start)
+ overlap_end = min(end, next_day)
+
+ # Excluded = time in [overlap_start, overlap_end] that is outside [win_open, win_close]
+ # = total overlap - time inside window
+ total_overlap = (overlap_end - overlap_start).total_seconds() / 60
+
+ # Compute overlap with the active window
+ active_start = max(overlap_start, win_open)
+ active_end = min(overlap_end, win_close)
+ active_minutes = max((active_end - active_start).total_seconds() / 60, 0)
+
+ excluded_on_day = total_overlap - active_minutes
+ if excluded_on_day > 0:
+ total_minutes += excluded_on_day
+
+ day_start = next_day
+
+ return total_minutes
+
+
+def add_business_minutes(
+ start: pd.Timestamp | datetime,
+ business_minutes: float,
+ exclude_weekends: bool,
+ holiday_dates: set[date] | None,
+ tz: str | None = None,
+ excluded_days: frozenset[int] | None = None,
+) -> pd.Timestamp:
+ """Advance wall-clock time by N business minutes, skipping excluded days.
+
+ Inverse of count_excluded_minutes: given a start time and a number of
+ business minutes to elapse, returns the wall-clock timestamp at which
+ those minutes will have passed, skipping weekends and holidays.
+
+ When tz is provided, naive timestamps are interpreted as UTC and day
+ boundary checks use the local timezone.
+ """
+ start = pd.Timestamp(start)
+ if business_minutes <= 0:
+ return start
+
+ has_exclusions = exclude_weekends or bool(holiday_dates) or bool(excluded_days)
+ if not has_exclusions:
+ return start + pd.Timedelta(minutes=business_minutes)
+
+ cursor = start
+ if tz:
+ cursor = _to_local(cursor, tz)
+
+ remaining = business_minutes
+
+ while remaining > 0:
+ day_start = cursor.normalize()
+ next_day = (day_start + pd.DateOffset(days=1)).normalize()
+
+ if is_excluded_day(cursor, exclude_weekends, holiday_dates, excluded_days=excluded_days):
+ cursor = next_day
+ # Skip consecutive excluded days
+ for _ in range(365):
+ if not is_excluded_day(cursor, exclude_weekends, holiday_dates, excluded_days=excluded_days):
+ break
+ cursor = (cursor + pd.DateOffset(days=1)).normalize()
+ continue
+
+ minutes_left_today = (next_day - cursor).total_seconds() / 60
+
+ if remaining <= minutes_left_today:
+ cursor = cursor + pd.Timedelta(minutes=remaining)
+ remaining = 0
+ else:
+ remaining -= minutes_left_today
+ cursor = next_day
+
+ if tz and start.tzinfo is None:
+ cursor = cursor.tz_convert("UTC").tz_localize(None)
+
+ return cursor
+
+
+def _is_in_time_window(hour: float, window_start: float, window_end: float) -> bool:
+ if window_start <= window_end:
+ return window_start <= hour <= window_end
+ return hour >= window_start or hour <= window_end
+
+
+def _to_local(ts: pd.Timestamp, tz: str) -> pd.Timestamp:
+ if ts.tzinfo is None:
+ ts = ts.tz_localize("UTC")
+ return ts.tz_convert(zoneinfo.ZoneInfo(tz))
+
+
+def _next_active_day(start: pd.Timestamp, active_days: frozenset[int], max_days: int = 14) -> pd.Timestamp | None:
+ candidate = start
+ for _ in range(max_days):
+ if candidate.weekday() in active_days:
+ return candidate
+ candidate += pd.Timedelta(days=1)
+ return None
+
+
+def _set_fractional_hour(ts: pd.Timestamp, fractional_hour: float) -> pd.Timestamp:
+ hour = int(fractional_hour)
+ minute = int((fractional_hour - hour) * 60)
+ return ts.replace(hour=hour, minute=minute, second=0, microsecond=0)
+
+
+def classify_frequency(gaps_hours: np.ndarray) -> str:
+ """Classify table update frequency from inter-update gaps.
+
+    Frequency does NOT gate the schedule stage — stage is determined by
+    confidence (>= 0.75 — "active"). However, frequency does affect which
+ threshold path is used: "sub_daily" enables within-window gap thresholds,
+ while other values use deadline-based thresholds.
+
+ Multi-day cadences (median 36-120h, e.g. Mon/Wed/Fri or Tue/Thu) classify
+ as "irregular" because they fall between the daily and weekly bands, but
+ they can still reach "active" stage when ``detect_active_days`` finds a
+ consistent day-of-week and time-of-day pattern.
+
+ Bands:
+ - sub_daily: median < 6h
+ - daily: 6h <= median < 36h
+ - weekly: 120h < median < 240h (roughly 5-10 days)
+ - irregular: everything else (0 gaps, 36-120h, 240h+)
+
+ :param gaps_hours: Array of gap durations in hours between consecutive updates.
+ :returns: One of "sub_daily", "daily", "weekly", "irregular".
+ """
+ if len(gaps_hours) == 0:
+ return "irregular"
+ median_gap = float(np.median(gaps_hours))
+ if median_gap < 6:
+ return "sub_daily"
+ elif median_gap < 36:
+ return "daily"
+ elif 120 < median_gap < 240:
+ return "weekly"
+ else:
+ return "irregular"
+
+
+def detect_active_days(
+ update_times: list[pd.Timestamp],
+ tz: str,
+ min_weeks: int = 3,
+) -> frozenset[int] | None:
+ """Detect which days of the week have updates.
+
+ :param update_times: Sorted list of update timestamps (UTC or naive-UTC).
+ :param tz: IANA timezone for local day-of-week mapping.
+ :param min_weeks: Minimum weeks of data needed.
+ :returns: frozenset of weekday numbers (0=Mon, 6=Sun) or None if insufficient data.
+ """
+ if len(update_times) < 2:
+ return None
+
+ local_times = [_to_local(t, tz) for t in update_times]
+
+ date_range_days = (local_times[-1] - local_times[0]).days
+ if date_range_days < min_weeks * 7:
+ return None
+
+ day_counts: Counter[int] = Counter(t.weekday() for t in local_times)
+ weeks_observed = max(1, date_range_days // 7)
+
+ active_days: set[int] = set()
+ for day in range(7):
+ count = day_counts.get(day, 0)
+ hit_rate = count / weeks_observed
+ if hit_rate >= 0.5:
+ active_days.add(day)
+
+ return frozenset(active_days) if active_days else None
+
+
+def detect_update_window(
+ update_times: list[pd.Timestamp],
+ active_days: frozenset[int],
+ tz: str,
+) -> tuple[float, float] | None:
+ """Detect the time-of-day window when updates arrive on active days.
+
+ :returns: (window_start, window_end) as hours 0-24, or None.
+ """
+ local_times = [_to_local(t, tz) for t in update_times]
+
+ hours_on_active_days = [
+ t.hour + t.minute / 60.0
+ for t in local_times
+ if t.weekday() in active_days
+ ]
+
+ if len(hours_on_active_days) < 10:
+ return None
+
+ # Handle midnight-wrapping clusters (e.g., 23:00-01:00)
+ shifted = False
+ late = sum(1 for h in hours_on_active_days if h >= 22) / len(hours_on_active_days)
+ early = sum(1 for h in hours_on_active_days if h < 3) / len(hours_on_active_days)
+ if late > 0.25 and early > 0.25:
+ hours_on_active_days = [(h + 12) % 24 for h in hours_on_active_days]
+ shifted = True
+
+ p10 = float(np.percentile(hours_on_active_days, 10))
+ p90 = float(np.percentile(hours_on_active_days, 90))
+
+ if shifted:
+ p10 = (p10 - 12) % 24
+ p90 = (p90 - 12) % 24
+
+ return (p10, p90)
+
+
+def compute_schedule_confidence(
+ update_times: list[pd.Timestamp],
+ schedule: InferredSchedule,
+ tz: str,
+) -> float:
+ """Fraction of historical updates that match the detected schedule.
+
+ An update "matches" if it falls on an active day and (if a window is defined)
+ within the P10-P90 time window.
+ """
+ if not update_times:
+ return 0.0
+
+ matching = 0
+ for t in update_times:
+ lt = _to_local(t, tz)
+ if lt.weekday() not in schedule.active_days:
+ continue
+ if schedule.window_start is not None and schedule.window_end is not None:
+ hour = lt.hour + lt.minute / 60.0
+ if not _is_in_time_window(hour, schedule.window_start, schedule.window_end):
+ continue
+ matching += 1
+ return matching / len(update_times)
+
+
+def infer_schedule(
+ history: pd.DataFrame,
+ tz: str,
+) -> InferredSchedule | None:
+ """Attempt to infer a table's update schedule from its freshness history.
+
+ :param history: DataFrame with DatetimeIndex and result_signal column (0 = update).
+ :param tz: IANA timezone for local time analysis.
+ :returns: InferredSchedule or None if insufficient data for any inference.
+ """
+ signal = history.iloc[:, 0]
+ update_times = list(signal.index[signal == 0])
+
+ if len(update_times) < 10:
+ return None
+
+ # Compute gaps in hours
+ gaps_hours = np.diff(update_times).astype("timedelta64[m]").astype(float) / 60.0
+
+ frequency = classify_frequency(gaps_hours)
+ num_events = len(update_times)
+
+ # Determine stage based on data quantity
+ local_times = [_to_local(t, tz) for t in update_times]
+ date_range_days = (local_times[-1] - local_times[0]).days
+
+ if date_range_days < 21 or num_events < 10:
+ return None # Not enough data for any inference
+
+ # Detect active days
+ active_days = detect_active_days(update_times, tz)
+ if active_days is None:
+ active_days = frozenset(range(7))
+
+ # Detect update window
+ window_result = detect_update_window(update_times, active_days, tz)
+ window_start = window_result[0] if window_result else None
+ window_end = window_result[1] if window_result else None
+
+ # Build preliminary schedule for confidence scoring
+ preliminary = InferredSchedule(
+ frequency=frequency,
+ active_days=active_days,
+ window_start=window_start,
+ window_end=window_end,
+ confidence=0.0,
+ num_events=num_events,
+ stage="training",
+ )
+
+ confidence = compute_schedule_confidence(update_times, preliminary, tz)
+
+ # Determine stage
+ if num_events < 20:
+ stage = "tentative"
+ elif confidence >= 0.75:
+ stage = "active"
+ elif confidence < 0.60:
+ stage = "irregular"
+ else:
+ stage = "tentative"
+
+ return preliminary._replace(confidence=confidence, stage=stage)
+
+
+def minutes_to_next_deadline(
+ last_update: pd.Timestamp,
+ schedule: InferredSchedule,
+ exclude_weekends: bool,
+ holiday_dates: set[date] | None,
+ tz: str,
+ buffer_hours: float,
+ excluded_days: frozenset[int] | None = None,
+) -> float | None:
+ if schedule.window_end is None:
+ return None
+
+ deadline_hour = (schedule.window_end + buffer_hours) % 24
+ local_last = _to_local(last_update, tz)
+
+ # Find the next active day after last_update
+ candidate = _next_active_day(local_last.normalize() + pd.Timedelta(days=1), schedule.active_days)
+ if candidate is None:
+ return None
+
+ # Set the deadline time on that day
+ deadline_ts = _set_fractional_hour(candidate, deadline_hour)
+
+ # If the deadline is already past relative to now, move to next active day
+ if deadline_ts <= local_last:
+ candidate = _next_active_day(candidate + pd.Timedelta(days=1), schedule.active_days)
+ if candidate is None:
+ return None
+ deadline_ts = _set_fractional_hour(candidate, deadline_hour)
+
+ # Convert both to UTC for consistent gap calculation
+ utc_last = local_last.tz_convert("UTC").tz_localize(None)
+ utc_deadline = deadline_ts.tz_convert("UTC").tz_localize(None)
+
+ wall_minutes = (utc_deadline - utc_last).total_seconds() / 60.0
+ if wall_minutes <= 0:
+ return None
+
+ if exclude_weekends or holiday_dates or excluded_days:
+ excl = count_excluded_minutes(utc_last, utc_deadline, exclude_weekends, holiday_dates, tz=tz, excluded_days=excluded_days)
+ return max(wall_minutes - excl, 0)
+
+ return wall_minutes
diff --git a/testgen/common/get_pipeline_parms.py b/testgen/common/get_pipeline_parms.py
deleted file mode 100644
index 1f10b364..00000000
--- a/testgen/common/get_pipeline_parms.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from typing import TypedDict
-
-from testgen.common.database.database_service import fetch_dict_from_db
-from testgen.common.read_file import read_template_sql_file
-
-
-class BaseParams(TypedDict):
- project_code: str
- connection_id: str
-
-class TestGenerationParams(BaseParams):
- export_to_observability: str
- test_suite_id: str
- profiling_as_of_date: str
-
-
-def get_test_generation_params(table_group_id: str, test_suite: str) -> TestGenerationParams:
- results = fetch_dict_from_db(
- read_template_sql_file("parms_test_gen.sql", "parms"),
- {"TABLE_GROUP_ID": table_group_id, "TEST_SUITE": test_suite},
- )
- if not results:
- raise ValueError("Connection parameters not found for test generation.")
- return TestGenerationParams(results[0])
diff --git a/testgen/common/models/entity.py b/testgen/common/models/entity.py
index 2fe3ac93..6d0b0950 100644
--- a/testgen/common/models/entity.py
+++ b/testgen/common/models/entity.py
@@ -1,5 +1,5 @@
from collections.abc import Iterable
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, fields
from typing import Any, Self
from uuid import UUID
@@ -23,7 +23,7 @@
class EntityMinimal:
@classmethod
def columns(cls) -> list[str]:
- return list(cls.__annotations__.keys())
+ return [f.name for f in fields(cls)]
def to_dict(self, json_safe: bool = False) -> dict[str, Any]:
result = asdict(self)
diff --git a/testgen/common/models/notification_settings.py b/testgen/common/models/notification_settings.py
index 66c616bb..e15349f4 100644
--- a/testgen/common/models/notification_settings.py
+++ b/testgen/common/models/notification_settings.py
@@ -35,10 +35,15 @@ class ProfilingRunNotificationTrigger(enum.Enum):
on_changes = "on_changes"
+class MonitorNotificationTrigger(enum.Enum):
+ on_anomalies = "on_anomalies"
+
+
class NotificationEvent(enum.Enum):
test_run = "test_run"
profiling_run = "profiling_run"
score_drop = "score_drop"
+ monitor_run = "monitor_run"
class NotificationSettingsValidationError(Exception):
@@ -96,7 +101,7 @@ def _base_select_query(
fk_count = len([None for fk in (test_suite_id, table_group_id, score_definition_id) if fk is not SENTINEL])
if fk_count > 1:
raise ValueError("Only one foreign key can be used at a time.")
- elif fk_count == 1 and (project_code is not SENTINEL or event is not SENTINEL):
+ elif fk_count == 1 and (project_code is not SENTINEL):
raise ValueError("Filtering by project_code or event is not allowed when filtering by a foreign key.")
query = select(cls)
@@ -284,3 +289,40 @@ def create(
)
ns.save()
return ns
+
+
+class MonitorNotificationSettings(RunNotificationSettings[MonitorNotificationTrigger]):
+
+ __mapper_args__: ClassVar = {
+ "polymorphic_identity": NotificationEvent.monitor_run,
+ }
+ trigger_enum = MonitorNotificationTrigger
+
+ @property
+ def table_name(self) -> str | None:
+ return self.settings["table_name"] if self.settings.get("table_name") else None
+
+ @table_name.setter
+ def table_name(self, value: str | None) -> None:
+ self.settings = {**self.settings, "table_name": value}
+
+ @classmethod
+ def create(
+ cls,
+ project_code: str,
+ table_group_id: UUID,
+ test_suite_id: UUID,
+ recipients: list[str],
+        trigger: MonitorNotificationTrigger,
+ table_name: str | None = None,
+ ) -> Self:
+ ns = cls(
+ event=NotificationEvent.monitor_run,
+ project_code=project_code,
+ table_group_id=table_group_id,
+ test_suite_id=test_suite_id,
+ recipients=recipients,
+ settings={"trigger": trigger.value, "table_name": table_name},
+ )
+ ns.save()
+ return ns
diff --git a/testgen/common/models/scheduler.py b/testgen/common/models/scheduler.py
index da3f7073..fa070e03 100644
--- a/testgen/common/models/scheduler.py
+++ b/testgen/common/models/scheduler.py
@@ -3,16 +3,19 @@
from typing import Any, Self
from uuid import UUID, uuid4
+import streamlit as st
from cron_converter import Cron
from sqlalchemy import Boolean, Column, String, cast, delete, func, select, update
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import InstrumentedAttribute
from testgen.common.models import Base, get_current_session
+from testgen.common.models.entity import ENTITY_HASH_FUNCS
from testgen.common.models.test_definition import TestDefinition
from testgen.common.models.test_suite import TestSuite
RUN_TESTS_JOB_KEY = "run-tests"
+RUN_MONITORS_JOB_KEY = "run-monitors"
RUN_PROFILE_JOB_KEY = "run-profile"
@@ -29,13 +32,20 @@ class JobSchedule(Base):
cron_tz: str = Column(String, nullable=False)
active: bool = Column(Boolean, default=True)
+ @classmethod
+ @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS)
+ def get(cls, *clauses) -> Self | None:
+ query = select(cls).where(*clauses)
+ return get_current_session().scalars(query).first()
+
@classmethod
def select_where(cls, *clauses, order_by: str | InstrumentedAttribute | None = None) -> Iterable[Self]:
+ test_job_keys = [RUN_TESTS_JOB_KEY, RUN_MONITORS_JOB_KEY]
test_definitions_count = (
select(cls.id)
.join(TestSuite, TestSuite.id == cast(cls.kwargs["test_suite_id"].astext, postgresql.UUID))
.join(TestDefinition, TestDefinition.test_suite_id == TestSuite.id)
- .where(cls.key == RUN_TESTS_JOB_KEY, cls.active == True)
+ .where(cls.key.in_(test_job_keys), cls.active == True)
.group_by(cls.id, TestSuite.id)
.having(func.count(TestDefinition.id) > 0)
.subquery()
@@ -45,14 +55,14 @@ def select_where(cls, *clauses, order_by: str | InstrumentedAttribute | None = N
.join(test_definitions_count, test_definitions_count.c.id == cls.id)
.where(*clauses)
)
- non_test_runs_query = select(cls).where(cls.key != RUN_TESTS_JOB_KEY, cls.active == True, *clauses)
+ non_test_runs_query = select(cls).where(cls.key.not_in(test_job_keys), cls.active == True, *clauses)
query = test_runs_query.union_all(non_test_runs_query).order_by(order_by)
return get_current_session().execute(query)
@classmethod
def delete(cls, job_id: str | UUID) -> None:
- query = delete(cls).where(JobSchedule.id == UUID(job_id))
+ query = delete(cls).where(JobSchedule.id == job_id)
db_session = get_current_session()
try:
db_session.execute(query)
@@ -60,10 +70,11 @@ def delete(cls, job_id: str | UUID) -> None:
db_session.rollback()
else:
db_session.commit()
+ cls.clear_cache()
@classmethod
def update_active(cls, job_id: str | UUID, active: bool) -> None:
- query = update(cls).where(JobSchedule.id == UUID(job_id)).values(active=active)
+ query = update(cls).where(JobSchedule.id == job_id).values(active=active)
db_session = get_current_session()
try:
db_session.execute(query)
@@ -71,10 +82,15 @@ def update_active(cls, job_id: str | UUID, active: bool) -> None:
db_session.rollback()
else:
db_session.commit()
+ cls.clear_cache()
@classmethod
def count(cls):
return get_current_session().query(cls).count()
+
+ @classmethod
+ def clear_cache(cls) -> None:
+ cls.get.clear()
def get_sample_triggering_timestamps(self, n=3) -> list[datetime]:
schedule = Cron(cron_string=self.cron_expr).schedule(timezone_str=self.cron_tz)
@@ -83,3 +99,9 @@ def get_sample_triggering_timestamps(self, n=3) -> list[datetime]:
@property
def cron_tz_str(self) -> str:
return self.cron_tz.replace("_", " ")
+
+ def save(self) -> None:
+ db_session = get_current_session()
+ db_session.add(self)
+ db_session.commit()
+ self.__class__.clear_cache()
diff --git a/testgen/common/models/table_group.py b/testgen/common/models/table_group.py
index ae24af04..938a851b 100644
--- a/testgen/common/models/table_group.py
+++ b/testgen/common/models/table_group.py
@@ -11,7 +11,6 @@
from testgen.common.models import get_current_session
from testgen.common.models.custom_types import NullIfEmptyString, YNString
from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal
-from testgen.common.models.scheduler import RUN_TESTS_JOB_KEY, JobSchedule
from testgen.common.models.scores import ScoreDefinition
from testgen.common.models.test_suite import TestSuite
@@ -28,6 +27,8 @@ class TableGroupMinimal(EntityMinimal):
profiling_exclude_mask: str
profile_use_sampling: bool
profiling_delay_days: str
+ monitor_test_suite_id: UUID | None
+ last_complete_profile_run_id: UUID | None
@dataclass
@@ -62,6 +63,25 @@ class TableGroupSummary(EntityMinimal):
latest_anomalies_likely_ct: int
latest_anomalies_possible_ct: int
latest_anomalies_dismissed_ct: int
+ monitor_test_suite_id: UUID | None
+ monitor_lookback: int | None
+ monitor_lookback_start: datetime | None
+ monitor_lookback_end: datetime | None
+ monitor_freshness_anomalies: int | None
+ monitor_schema_anomalies: int | None
+ monitor_volume_anomalies: int | None
+ monitor_metric_anomalies: int | None
+ monitor_freshness_has_errors: bool | None
+ monitor_volume_has_errors: bool | None
+ monitor_schema_has_errors: bool | None
+ monitor_metric_has_errors: bool | None
+ monitor_freshness_is_training: bool | None
+ monitor_volume_is_training: bool | None
+ monitor_metric_is_training: bool | None
+ monitor_freshness_is_pending: bool | None
+ monitor_volume_is_pending: bool | None
+ monitor_schema_is_pending: bool | None
+ monitor_metric_is_pending: bool | None
class TableGroup(Entity):
@@ -70,6 +90,11 @@ class TableGroup(Entity):
id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, default=uuid4)
project_code: str = Column(String, ForeignKey("projects.project_code"))
connection_id: int = Column(BigInteger, ForeignKey("connections.connection_id"))
+ default_test_suite_id: UUID | None = Column(
+ postgresql.UUID(as_uuid=True),
+ ForeignKey("test_suites.id"),
+ default=None,
+ )
monitor_test_suite_id: UUID | None = Column(
postgresql.UUID(as_uuid=True),
ForeignKey("test_suites.id"),
@@ -223,6 +248,59 @@ def select_summary(cls, project_code: str, for_dashboard: bool = False) -> Itera
anomaly_types.id = latest_anomalies.anomaly_id
)
GROUP BY latest_run.id
+ ),
+ ranked_test_runs AS (
+ SELECT
+ table_groups.id AS table_group_id,
+ test_runs.id,
+ test_runs.test_starttime,
+ COALESCE(test_suites.monitor_lookback, 1) AS lookback,
+ ROW_NUMBER() OVER (PARTITION BY test_runs.test_suite_id ORDER BY test_runs.test_starttime DESC) AS position
+ FROM table_groups
+ INNER JOIN test_runs
+ ON (test_runs.test_suite_id = table_groups.monitor_test_suite_id)
+ INNER JOIN test_suites
+ ON (table_groups.monitor_test_suite_id = test_suites.id)
+ WHERE table_groups.project_code = :project_code
+ ),
+ monitor_tables AS (
+ SELECT
+ ranked_test_runs.table_group_id,
+ SUM(CASE WHEN results.test_type = 'Freshness_Trend' AND results.result_code = 0 THEN 1 ELSE 0 END) AS freshness_anomalies,
+ SUM(CASE WHEN results.test_type = 'Schema_Drift' AND results.result_code = 0 THEN 1 ELSE 0 END) AS schema_anomalies,
+ SUM(CASE WHEN results.test_type = 'Volume_Trend' AND results.result_code = 0 THEN 1 ELSE 0 END) AS volume_anomalies,
+ SUM(CASE WHEN results.test_type = 'Metric_Trend' AND results.result_code = 0 THEN 1 ELSE 0 END) AS metric_anomalies,
+ BOOL_OR(results.result_status = 'Error') FILTER (WHERE results.test_type = 'Freshness_Trend' AND ranked_test_runs.position = 1) AS freshness_has_errors,
+ BOOL_OR(results.result_status = 'Error') FILTER (WHERE results.test_type = 'Volume_Trend' AND ranked_test_runs.position = 1) AS volume_has_errors,
+ BOOL_OR(results.result_status = 'Error') FILTER (WHERE results.test_type = 'Schema_Drift' AND ranked_test_runs.position = 1) AS schema_has_errors,
+ BOOL_OR(results.result_status = 'Error') FILTER (WHERE results.test_type = 'Metric_Trend' AND ranked_test_runs.position = 1) AS metric_has_errors,
+ BOOL_AND(results.result_code = -1) FILTER (WHERE results.test_type = 'Freshness_Trend' AND ranked_test_runs.position = 1) AS freshness_is_training,
+ BOOL_AND(results.result_code = -1) FILTER (WHERE results.test_type = 'Volume_Trend' AND ranked_test_runs.position = 1) AS volume_is_training,
+ BOOL_AND(results.result_code = -1) FILTER (WHERE results.test_type = 'Metric_Trend' AND ranked_test_runs.position = 1) AS metric_is_training,
+ BOOL_OR(results.test_type = 'Freshness_Trend') IS NOT TRUE AS freshness_is_pending,
+ BOOL_OR(results.test_type = 'Volume_Trend') IS NOT TRUE AS volume_is_pending,
+ -- Schema monitor only creates results on schema changes (Failed)
+ -- Mark it as pending only if there are no results of any test type
+ BOOL_OR(results.test_time IS NOT NULL) IS NOT TRUE AS schema_is_pending,
+ BOOL_OR(results.test_type = 'Metric_Trend') IS NOT TRUE AS metric_is_pending
+ FROM ranked_test_runs
+ INNER JOIN test_results AS results
+ ON (results.test_run_id = ranked_test_runs.id)
+ WHERE ranked_test_runs.position <= ranked_test_runs.lookback
+ AND results.table_name IS NOT NULL
+ GROUP BY ranked_test_runs.table_group_id
+ ),
+ lookback_windows AS (
+ SELECT
+ table_group_id,
+ lookback,
+ MIN(test_starttime) FILTER (WHERE position = LEAST(lookback + 1, max_position)) AS lookback_start,
+ MAX(test_starttime) FILTER (WHERE position = 1) AS lookback_end
+ FROM (
+ SELECT *, MAX(position) OVER (PARTITION BY table_group_id) as max_position
+ FROM ranked_test_runs
+ ) pos
+ GROUP BY table_group_id, lookback
)
SELECT groups.id,
groups.table_groups_name,
@@ -240,10 +318,31 @@ def select_summary(cls, project_code: str, for_dashboard: bool = False) -> Itera
latest_profile.definite_ct AS latest_anomalies_definite_ct,
latest_profile.likely_ct AS latest_anomalies_likely_ct,
latest_profile.possible_ct AS latest_anomalies_possible_ct,
- latest_profile.dismissed_ct AS latest_anomalies_dismissed_ct
+ latest_profile.dismissed_ct AS latest_anomalies_dismissed_ct,
+ groups.monitor_test_suite_id AS monitor_test_suite_id,
+ lookback_windows.lookback AS monitor_lookback,
+ lookback_windows.lookback_start AS monitor_lookback_start,
+ lookback_windows.lookback_end AS monitor_lookback_end,
+ monitor_tables.freshness_anomalies AS monitor_freshness_anomalies,
+ monitor_tables.schema_anomalies AS monitor_schema_anomalies,
+ monitor_tables.volume_anomalies AS monitor_volume_anomalies,
+ monitor_tables.metric_anomalies AS monitor_metric_anomalies,
+ monitor_tables.freshness_has_errors AS monitor_freshness_has_errors,
+ monitor_tables.volume_has_errors AS monitor_volume_has_errors,
+ monitor_tables.schema_has_errors AS monitor_schema_has_errors,
+ monitor_tables.metric_has_errors AS monitor_metric_has_errors,
+ monitor_tables.freshness_is_training AS monitor_freshness_is_training,
+ monitor_tables.volume_is_training AS monitor_volume_is_training,
+ monitor_tables.metric_is_training AS monitor_metric_is_training,
+ monitor_tables.freshness_is_pending AS monitor_freshness_is_pending,
+ monitor_tables.volume_is_pending AS monitor_volume_is_pending,
+ monitor_tables.schema_is_pending AS monitor_schema_is_pending,
+ monitor_tables.metric_is_pending AS monitor_metric_is_pending
FROM table_groups AS groups
LEFT JOIN stats ON (groups.id = stats.table_groups_id)
LEFT JOIN latest_profile ON (groups.id = latest_profile.table_groups_id)
+ LEFT JOIN monitor_tables ON (groups.id = monitor_tables.table_group_id)
+ LEFT JOIN lookback_windows ON (groups.id = lookback_windows.table_group_id)
WHERE groups.project_code = :project_code
{"AND groups.include_in_dashboard IS TRUE" if for_dashboard else ""};
"""
@@ -331,12 +430,7 @@ def clear_cache(cls) -> bool:
cls.select_minimal_where.clear()
cls.select_summary.clear()
- def save(
- self,
- add_scorecard_definition: bool = False,
- add_monitor_test_suite: bool = False,
- monitor_schedule_timezone: str = "UTC",
- ) -> None:
+ def save(self, add_scorecard_definition: bool = False) -> None:
if self.id:
values = {
column.key: getattr(self, column.key, None)
@@ -349,38 +443,6 @@ def save(
db_session.commit()
else:
super().save()
- db_session = get_current_session()
-
if add_scorecard_definition:
ScoreDefinition.from_table_group(self).save()
-
- if add_monitor_test_suite:
- test_suite = TestSuite(
- project_code=self.project_code,
- test_suite=f"{self.table_groups_name} Monitor",
- connection_id=self.connection_id,
- table_groups_id=self.id,
- export_to_observability=False,
- dq_score_exclude=True,
- view_mode="Monitor",
- )
- test_suite.save()
-
- schedule_job = JobSchedule(
- project_code=self.project_code,
- key=RUN_TESTS_JOB_KEY,
- cron_expr="0 * * * *",
- cron_tz=monitor_schedule_timezone,
- args=[],
- kwargs={"test_suite_id": test_suite.id},
- )
- db_session.add(schedule_job)
-
- self.monitor_test_suite_id = test_suite.id
- db_session.execute(
- update(TableGroup)
- .where(TableGroup.id == self.id).values(monitor_test_suite_id=test_suite.id)
- )
- db_session.commit()
-
TableGroup.clear_cache()
diff --git a/testgen/common/models/test_definition.py b/testgen/common/models/test_definition.py
index 195121b7..a18b8223 100644
--- a/testgen/common/models/test_definition.py
+++ b/testgen/common/models/test_definition.py
@@ -33,7 +33,21 @@
@dataclass
-class TestDefinitionSummary(EntityMinimal):
+class TestTypeSummary(EntityMinimal):
+ test_name_short: str
+ default_test_description: str
+ measure_uom: str
+ measure_uom_description: str
+ default_parm_columns: str
+ default_parm_prompts: str
+ default_parm_help: str
+ default_severity: str
+ test_scope: TestScope
+ usage_notes: str
+
+
+@dataclass
+class TestDefinitionSummary(TestTypeSummary):
id: UUID
table_groups_id: UUID
profile_run_id: UUID
@@ -67,25 +81,17 @@ class TestDefinitionSummary(EntityMinimal):
match_having_condition: str
custom_query: str
history_calculation: str
+ history_calculation_upper: str
history_lookback: int
- test_active: str
+ test_active: bool
test_definition_status: str
severity: str
- lock_refresh: str
+ lock_refresh: bool
last_auto_gen_date: datetime
profiling_as_of_date: datetime
last_manual_update: datetime
- export_to_observability: str
- test_name_short: str
- default_test_description: str
- measure_uom: str
- measure_uom_description: str
- default_parm_columns: str
- default_parm_prompts: str
- default_parm_help: str
- default_severity: str
- test_scope: TestScope
- usage_notes: str
+ export_to_observability: bool
+ prediction: str | None
@dataclass
@@ -143,6 +149,17 @@ class TestType(Entity):
usage_notes: str = Column(String)
active: str = Column(String)
+ _summary_columns = (
+ *[key for key in TestTypeSummary.__annotations__.keys() if key != "default_test_description"],
+ test_description.label("default_test_description"),
+ )
+
+ @classmethod
+ @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS)
+ def select_summary_where(cls, *clauses) -> Iterable[TestTypeSummary]:
+ results = cls._select_columns_where(cls._summary_columns, *clauses)
+ return [TestTypeSummary(**row) for row in results]
+
class TestDefinition(Entity):
__tablename__ = "test_definitions"
@@ -179,6 +196,7 @@ class TestDefinition(Entity):
match_groupby_names: str = Column(NullIfEmptyString)
match_having_condition: str = Column(NullIfEmptyString)
history_calculation: str = Column(NullIfEmptyString)
+ history_calculation_upper: str = Column(NullIfEmptyString)
history_lookback: int = Column(ZeroIfEmptyInteger, default=0)
test_mode: str = Column(String)
custom_query: str = Column(QueryString)
@@ -192,10 +210,12 @@ class TestDefinition(Entity):
profiling_as_of_date: datetime = Column(postgresql.TIMESTAMP)
last_manual_update: datetime = Column(UpdateTimestamp, nullable=False)
export_to_observability: bool = Column(YNString)
+ prediction: dict[str, dict[str, float]] | None = Column(postgresql.JSONB)
_default_order_by = (asc(func.lower(schema_name)), asc(func.lower(table_name)), asc(func.lower(column_name)), asc(test_type))
_summary_columns = (
- *[key for key in TestDefinitionSummary.__annotations__.keys() if key != "default_test_description"],
+ *TestDefinitionSummary.__annotations__.keys(),
+ *[key for key in TestTypeSummary.__annotations__.keys() if key != "default_test_description"],
TestType.test_description.label("default_test_description"),
)
_minimal_columns = TestDefinitionMinimal.__annotations__.keys()
@@ -211,6 +231,7 @@ class TestDefinition(Entity):
check_result,
last_auto_gen_date,
profiling_as_of_date,
+ prediction,
)
@classmethod
diff --git a/testgen/common/models/test_result.py b/testgen/common/models/test_result.py
index de96c97e..dd8d9ded 100644
--- a/testgen/common/models/test_result.py
+++ b/testgen/common/models/test_result.py
@@ -2,7 +2,7 @@
from collections import defaultdict
from uuid import UUID, uuid4
-from sqlalchemy import Boolean, Column, Enum, ForeignKey, String, and_, or_, select
+from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, or_, select
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import aliased
@@ -40,6 +40,7 @@ class TestResult(Entity):
status: TestResultStatus = Column("result_status", Enum(TestResultStatus))
message: str = Column("result_message", String)
+ result_code: int = Column(Integer)
# Note: not all table columns are implemented by this entity
@classmethod
@@ -50,22 +51,7 @@ def diff(cls, test_run_id_a: UUID, test_run_id_b: UUID) -> list[TestResultDiffTy
alias_a.status, alias_b.status, alias_b.test_definition_id,
).join(
alias_b,
- or_(
- and_(
- alias_a.auto_gen.is_(True),
- alias_b.auto_gen.is_(True),
- alias_a.test_suite_id == alias_b.test_suite_id,
- alias_a.schema_name == alias_b.schema_name,
- alias_a.table_name.isnot_distinct_from(alias_b.table_name),
- alias_a.column_names.isnot_distinct_from(alias_b.column_names),
- alias_a.test_type == alias_b.test_type,
- ),
- and_(
- alias_a.auto_gen.isnot(True),
- alias_b.auto_gen.isnot(True),
- alias_a.test_definition_id == alias_b.test_definition_id,
- ),
- ),
+ alias_a.test_definition_id == alias_b.test_definition_id,
full=True,
).where(
or_(alias_a.test_run_id == test_run_id_a, alias_a.test_run_id.is_(None)),
diff --git a/testgen/common/models/test_run.py b/testgen/common/models/test_run.py
index 01671315..3709328a 100644
--- a/testgen/common/models/test_run.py
+++ b/testgen/common/models/test_run.py
@@ -12,7 +12,9 @@
from testgen.common.models import get_current_session
from testgen.common.models.entity import Entity, EntityMinimal
-from testgen.common.models.test_result import TestResultStatus
+from testgen.common.models.project import Project
+from testgen.common.models.table_group import TableGroup
+from testgen.common.models.test_result import TestResult, TestResultStatus
from testgen.common.models.test_suite import TestSuite
from testgen.utils import is_uuid4
@@ -61,6 +63,20 @@ class TestRunSummary(EntityMinimal):
dismissed_ct: int
dq_score_testing: float
+
+@dataclass
+class TestRunMonitorSummary(EntityMinimal):
+ test_run_id: UUID
+ table_group_id: UUID
+ test_endtime: datetime
+ table_groups_name: str
+ project_name: str
+ freshness_anomalies: int
+ schema_anomalies: int
+ volume_anomalies: int
+ table_name: str | None = None
+
+
class LatestTestRun(NamedTuple):
id: str
run_time: datetime
@@ -240,7 +256,7 @@ def select_summary(
INNER JOIN test_suites ON (test_runs.test_suite_id = test_suites.id)
INNER JOIN table_groups ON (test_suites.table_groups_id = table_groups.id)
INNER JOIN projects ON (test_suites.project_code = projects.project_code)
- WHERE TRUE
+ WHERE test_suites.is_monitor IS NOT TRUE
{" AND test_suites.project_code = :project_code" if project_code else ""}
{" AND test_suites.table_groups_id = :table_group_id" if table_group_id else ""}
{" AND test_suites.id = :test_suite_id" if test_suite_id else ""}
@@ -257,6 +273,54 @@ def select_summary(
results = db_session.execute(text(query), params).mappings().all()
return [TestRunSummary(**row) for row in results]
+    def get_monitoring_summary(self, table_name: str | None = None) -> TestRunMonitorSummary:
+        freshness_anomalies = func.sum(case(
+            ((TestResult.test_type == "Freshness_Trend") & (TestResult.result_code == 0), 1),
+            else_=0,
+        ))
+        schema_anomalies = func.sum(case(
+            ((TestResult.test_type == "Schema_Drift") & (TestResult.result_code == 0), 1),
+            else_=0,
+        ))
+        volume_anomalies = func.sum(case(
+            ((TestResult.test_type == "Volume_Trend") & (TestResult.result_code == 0), 1),
+            else_=0,
+        ))
+        projection = [
+            TestRun.id.label("test_run_id"),
+            TestRun.test_endtime,
+            TableGroup.id.label("table_group_id"),
+            TableGroup.table_groups_name,
+            Project.project_name,
+            freshness_anomalies.label("freshness_anomalies"),
+            schema_anomalies.label("schema_anomalies"),
+            volume_anomalies.label("volume_anomalies"),
+        ]
+        group_by = [
+            TestRun.id,
+            TestRun.test_endtime,
+            TableGroup.id,
+            TableGroup.table_groups_name,
+            Project.project_name,
+        ]
+        if table_name:
+            projection.append(TestResult.table_name)
+            group_by.append(TestResult.table_name)
+
+        query = (
+            select(*projection)
+            .join(TableGroup, TableGroup.monitor_test_suite_id == TestRun.test_suite_id)
+            .join(Project, Project.project_code == TableGroup.project_code)
+            .join(TestResult, TestResult.test_run_id == TestRun.id)
+            .where(
+                TestRun.id == self.id,
+                (TestResult.table_name == table_name) if table_name else True,
+            )
+            .group_by(*group_by)
+        )
+
+        return TestRunMonitorSummary(**get_current_session().execute(query).mappings().first())
+
@classmethod
def has_running_process(cls, ids: list[str]) -> bool:
query = select(func.count(cls.id)).where(cls.id.in_(ids), cls.status == "Running")
diff --git a/testgen/common/models/test_suite.py b/testgen/common/models/test_suite.py
index ce7d601c..a8c35b8d 100644
--- a/testgen/common/models/test_suite.py
+++ b/testgen/common/models/test_suite.py
@@ -1,10 +1,11 @@
+import enum
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from uuid import UUID, uuid4
import streamlit as st
-from sqlalchemy import BigInteger, Boolean, Column, ForeignKey, String, asc, func, text
+from sqlalchemy import BigInteger, Boolean, Column, Enum, ForeignKey, Integer, String, asc, func, text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import InstrumentedAttribute
@@ -14,6 +15,11 @@
from testgen.utils import is_uuid4
+class PredictSensitivity(enum.Enum):
+ low = "low"
+ medium = "medium"
+ high = "high"
+
@dataclass
class TestSuiteMinimal(EntityMinimal):
id: UUID
@@ -46,7 +52,6 @@ class TestSuiteSummary(EntityMinimal):
last_run_log_ct: int
last_run_dismissed_ct: int
-
class TestSuite(Entity):
__tablename__ = "test_suites"
@@ -63,7 +68,19 @@ class TestSuite(Entity):
component_name: str = Column(NullIfEmptyString)
last_complete_test_run_id: UUID = Column(postgresql.UUID(as_uuid=True))
dq_score_exclude: bool = Column(Boolean, default=False)
- view_mode: str | None = Column(NullIfEmptyString, default=None)
+ is_monitor: bool = Column(Boolean, default=False)
+ monitor_lookback: int | None = Column(Integer)
+ monitor_regenerate_freshness: bool = Column(Boolean, default=True)
+ predict_sensitivity: PredictSensitivity | None = Column(Enum(PredictSensitivity))
+ predict_min_lookback: int | None = Column(Integer)
+ predict_exclude_weekends: bool = Column(Boolean, default=False)
+ predict_holiday_codes: str | None = Column(String)
+
+ @property
+ def holiday_codes_list(self) -> list[str] | None:
+ if not self.predict_holiday_codes:
+ return None
+ return [code.strip() for code in self.predict_holiday_codes.split(",")]
_default_order_by = (asc(func.lower(test_suite)),)
_minimal_columns = TestSuiteMinimal.__annotations__.keys()
@@ -179,7 +196,8 @@ def select_summary(cls, project_code: str, table_group_id: str | UUID | None = N
ON (connections.connection_id = suites.connection_id)
LEFT JOIN table_groups AS groups
ON (groups.id = suites.table_groups_id)
- WHERE suites.project_code = :project_code
+ WHERE suites.is_monitor IS NOT TRUE
+ AND suites.project_code = :project_code
{"AND suites.table_groups_id = :table_group_id" if table_group_id else ""}
ORDER BY LOWER(suites.test_suite);
"""
diff --git a/testgen/common/notifications/monitor_run.py b/testgen/common/notifications/monitor_run.py
new file mode 100644
index 00000000..c3893aad
--- /dev/null
+++ b/testgen/common/notifications/monitor_run.py
@@ -0,0 +1,310 @@
+import logging
+
+from testgen.common.models import with_database_session
+from testgen.common.models.notification_settings import (
+ MonitorNotificationSettings,
+ MonitorNotificationTrigger,
+)
+from testgen.common.models.project import Project
+from testgen.common.models.settings import PersistedSetting
+from testgen.common.models.table_group import TableGroup
+from testgen.common.models.test_result import TestResult, TestResultStatus
+from testgen.common.models.test_run import TestRun
+from testgen.common.notifications.notifications import BaseNotificationTemplate
+from testgen.utils import log_and_swallow_exception
+
+LOG = logging.getLogger("testgen")
+
+
+class MonitorEmailTemplate(BaseNotificationTemplate):
+
+ def get_subject_template(self) -> str:
+ return (
+ "[TestGen] Monitors Alert: {{summary.table_groups_name}}"
+ "{{#if summary.table_name}} | {{summary.table_name}}{{/if}}"
+ ' | {{total_anomalies}} {{pluralize total_anomalies "anomaly" "anomalies"}}'
+ )
+
+ def get_title_template(self):
+ return "Monitors Alert: {{summary.table_groups_name}}"
+
+ def get_main_content_template(self):
+ return """
+
+
+
+ | Project |
+ {{summary.project_name}} |
+
+
+ | Table Group |
+ {{summary.table_groups_name}} |
+
+ {{#if summary.table_name}}
+
+ | Table |
+ {{summary.table_name}} |
+
+ {{/if}}
+
+ | Time |
+ {{format_dt summary.test_endtime}} |
+
+
+
+
+
+
+ | Anomalies Summary |
+
+ View on TestGen >
+ |
+
+ |
+
+
+
+
+
+
+ {{#each summary_tags}}
+ {{>summary_tag .}}
+ {{/each}}
+
+
+
+ |
+
+
+
+
+
+
+ | Table |
+ Monitor |
+ Details |
+
+ {{#each anomalies}}
+
+ | {{truncate 30 table_name}} |
+ {{type}} |
+ {{details}} |
+
+ {{/each}}
+
+
+ |
+
+ {{#if truncated}}
+ + {{truncated}} more
+ {{/if}}
+ |
+
+
+
+ """
+
+ def get_summary_tag_template(self):
+ return """
+
+
+
+ |
+
+ {{{badge_content}}}
+
+ |
+ {{type}} |
+
+
+ |
+ """
+
+ def get_extra_css_template(self) -> str:
+ return """
+ .tg-summary-bar {
+ width: 350px;
+ border-radius: 4px;
+ overflow: hidden;
+ }
+
+ .tg-summary-bar td {
+ height: 10px;
+ padding: 0;
+ line-height: 10px;
+ font-size: 0;
+ }
+
+ .tg-summary-bar--caption {
+ margin-top: 4px;
+ color: var(--caption-text-color);
+ font-size: 13px;
+ font-style: italic;
+ line-height: 1;
+ }
+
+ .tg-summary-bar--legend {
+ width: auto;
+ margin-right: 8px;
+ }
+
+ .tg-summary-bar--legend-dot {
+ margin-right: 2px;
+ font-style: normal;
+ }
+ """
+
+
+@log_and_swallow_exception
+@with_database_session
+def send_monitor_notifications(test_run: TestRun, result_list_ct=20):
+    notifications = list(MonitorNotificationSettings.select(
+        enabled=True,
+        test_suite_id=test_run.test_suite_id,
+    ))
+    if not notifications:
+        return
+
+    triggers = {MonitorNotificationTrigger.on_anomalies}
+    notifications = [ns for ns in notifications if ns.trigger in triggers]
+    if not notifications:
+        return
+
+    table_group = next(iter(TableGroup.select_where(TableGroup.monitor_test_suite_id == test_run.test_suite_id)), None)
+    if not table_group:
+        return
+
+    project = Project.get(table_group.project_code)
+    for notification in notifications:
+        table_name = notification.settings.get("table_name")
+        test_results = list(TestResult.select_where(
+            TestResult.test_run_id == test_run.id,
+            (TestResult.table_name == table_name) if table_name else True,
+        ))
+        anomaly_results = [r for r in test_results if r.result_code == 0]
+
+        if not anomaly_results:
+            continue
+
+        anomalies = []
+        for test_result in anomaly_results:
+            label = _TEST_TYPE_LABELS.get(test_result.test_type)
+            details = test_result.message or "N/A"
+
+            if test_result.test_type == "Freshness_Trend":
+                parts = details.split(". ", 1)
+                message = parts[1].rstrip(".") if len(parts) > 1 else None
+                prefix = "Table updated" if "detected: Yes" in details else "No table update"
+                details = f"{prefix} - {message}" if message else prefix
+            elif test_result.test_type == "Metric_Trend":
+                label = f"{label}: {test_result.column_names}"
+
+            anomalies.append({
+                "table_name": test_result.table_name or "N/A",
+                "type": label,
+                "details": details,
+            })
+
+        view_in_testgen_url = "".join(
+            (
+                PersistedSetting.get("BASE_URL", ""),
+                "/monitors?project_code=",
+                str(table_group.project_code),
+                "&table_group_id=",
+                str(table_group.id),
+                "&table_name_filter=" if table_name else "",
+                table_name if table_name else "",
+                "&source=email",
+            )
+        )
+        try:
+            MonitorEmailTemplate().send(
+                notification.recipients,
+                {
+                    "summary": {
+                        "test_endtime": test_run.test_endtime,
+                        "table_groups_name": table_group.table_groups_name,
+                        "project_name": project.project_name,
+                        "table_name": table_name,
+                    },
+                    "total_anomalies": len(anomaly_results),
+                    "summary_tags": _build_summary_tags(test_results),
+                    "anomalies": anomalies[:result_list_ct],
+                    "truncated": max(len(anomalies) - result_list_ct, 0),
+                    "view_in_testgen_url": view_in_testgen_url,
+                },
+            )
+        except Exception:
+            LOG.exception("Failed sending monitor email notifications")
+
+
+_TEST_TYPE_LABELS = {
+ "Freshness_Trend": "Freshness",
+ "Volume_Trend": "Volume",
+ "Schema_Drift": "Schema",
+ "Metric_Trend": "Metric",
+}
+
+_BADGE_BASE = "text-align: center; font-weight: bold; font-size: 13px;"
+_BADGE_STYLES = {
+ "anomaly": f"background-color: #EF5350; min-width: 15px; padding: 0 5px; border-radius: 10px; line-height: 20px; color: #ffffff; {_BADGE_BASE}",
+ "error": f"width: 20px; height: 20px; line-height: 20px; color: #FFA726; font-size: 16px; {_BADGE_BASE}",
+ "training": f"border: 2px solid #42A5F5; width: 20px; height: 20px; border-radius: 50%; line-height: 16px; color: #42A5F5; box-sizing: border-box; {_BADGE_BASE}",
+ "pending": f"width: 20px; height: 20px; line-height: 20px; color: #9E9E9E; {_BADGE_BASE}",
+ "passed": f"background-color: #9CCC65; width: 20px; height: 20px; border-radius: 50%; line-height: 21px; color: #ffffff; {_BADGE_BASE}",
+}
+_BADGE_CONTENT = {
+ "error": "⚠",
+ "training": "···",
+ "pending": "—",
+ "passed": "✓",
+}
+
+
+def _build_summary_tags(test_results: list[TestResult]) -> list[dict]:
+ has_any_results = bool(test_results)
+ tags = []
+ for type_key, label in _TEST_TYPE_LABELS.items():
+ type_results = [r for r in test_results if r.test_type == type_key]
+ anomaly_count = sum(1 for r in type_results if r.result_code == 0)
+ has_errors = any(r.status == TestResultStatus.Error for r in type_results)
+
+ # Schema Drift only creates results on detected changes, and has no training phase.
+ # Pending = no results of any type; no Schema results but other types ran = passed.
+ if type_key == "Schema_Drift":
+ is_pending = not has_any_results
+ is_training = False
+ else:
+ is_pending = not type_results
+ is_training = bool(type_results) and all(r.result_code == -1 for r in type_results)
+
+ # Priority matches UI: anomalies > errors > training > pending > passed
+ if anomaly_count > 0:
+ state = "anomaly"
+ elif has_errors:
+ state = "error"
+ elif is_training:
+ state = "training"
+ elif is_pending:
+ state = "pending"
+ else:
+ state = "passed"
+
+ tags.append({
+ "type": label,
+ "badge_style": _BADGE_STYLES[state],
+ "badge_content": str(anomaly_count) if state == "anomaly" else _BADGE_CONTENT[state],
+ })
+ return tags
diff --git a/testgen/common/notifications/notifications.py b/testgen/common/notifications/notifications.py
index e26fdab9..b4343e2e 100644
--- a/testgen/common/notifications/notifications.py
+++ b/testgen/common/notifications/notifications.py
@@ -393,7 +393,7 @@ def get_body_template(self) -> str: