From 6d048c43f0753d96976f3c9e72262cfe3b27d052 Mon Sep 17 00:00:00 2001
From: Daniel Standish <15932138+dstandish@users.noreply.github.com>
Date: Tue, 14 Jan 2025 16:19:26 -0800
Subject: [PATCH] Run the task with the configured dag bundle (#44752)
Ensures that dag runs are created with a reference to the bundle version that was in effect at the time of creation. When a dag run has bundle info, the task is run with that dag bundle version.
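
The moving pieces, as an illustrative sketch (the bundle name and version are made-up values; the comments point at the files changed below):

    from airflow.executors.workloads import BundleInfo

    # 1. Run creation (airflow/models/dag.py): the dag's current bundle version is
    #    snapshotted onto the new run as DagRun.bundle_version.
    # 2. Queueing (airflow/executors/workloads.py): ExecuteTask.make() packages the
    #    bundle name from the DagModel and the version from the DagRun:
    bundle_info = BundleInfo(name="my-bundle", version="abc123")
    # 3. Execution (task_sdk supervisor and task runner): bundle_info travels in
    #    StartupDetails, and the dag file is resolved relative to that bundle
    #    version before parsing.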
---
.../commands/remote_commands/task_command.py | 5 +-
airflow/dag_processing/processor.py | 1 +
airflow/executors/local_executor.py | 3 +-
airflow/executors/workloads.py | 27 +-
.../versions/0050_3_0_0_add_dagbundlemodel.py | 4 +
airflow/models/dag.py | 8 +
airflow/models/dagrun.py | 3 +
docs/apache-airflow/img/airflow_erd.sha256 | 2 +-
docs/apache-airflow/img/airflow_erd.svg | 1020 +++++++++--------
.../providers/edge/cli/edge_command.py | 3 +-
providers/tests/edge/cli/test_edge_command.py | 3 +-
.../edge/executors/test_edge_executor.py | 3 +-
.../airflow/sdk/api/datamodels/_generated.py | 5 +
.../src/airflow/sdk/execution_time/comms.py | 4 +-
.../airflow/sdk/execution_time/supervisor.py | 31 +-
.../airflow/sdk/execution_time/task_runner.py | 13 +-
task_sdk/tests/execution_time/conftest.py | 5 +-
.../tests/execution_time/test_supervisor.py | 67 +-
.../tests/execution_time/test_task_runner.py | 37 +-
tests/api_fastapi/common/test_exceptions.py | 6 +-
.../remote_commands/test_task_command.py | 100 +-
tests/executors/test_local_executor.py | 16 +-
tests/models/test_dag.py | 2 +-
23 files changed, 751 insertions(+), 617 deletions(-)
diff --git a/airflow/cli/commands/remote_commands/task_command.py b/airflow/cli/commands/remote_commands/task_command.py
index 6c591801515d5..2329ee25bccaa 100644
--- a/airflow/cli/commands/remote_commands/task_command.py
+++ b/airflow/cli/commands/remote_commands/task_command.py
@@ -245,6 +245,9 @@ def _get_ti(
# we do refresh_from_task so that if TI has come back via RPC, we ensure that ti.task
# is the original task object and not the result of the round trip
ti.refresh_from_task(task, pool_override=pool)
+
+ ti.dag_model # we must ensure dag model is loaded eagerly for bundle info
+
return ti, dr_created
@@ -286,7 +289,7 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None:
if executor.queue_workload.__func__ is not BaseExecutor.queue_workload: # type: ignore[attr-defined]
from airflow.executors import workloads
- workload = workloads.ExecuteTask.make(ti, dag_path=dag.relative_fileloc)
+ workload = workloads.ExecuteTask.make(ti, dag_rel_path=dag.relative_fileloc)
with create_session() as session:
executor.queue_workload(workload, session)
else:
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index c175f3f68c726..5d73c50e7d736 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -203,6 +203,7 @@ class DagFileProcessorProcess(WatchedSubprocess):
@classmethod
def start( # type: ignore[override]
cls,
+ *,
path: str | os.PathLike[str],
callbacks: list[CallbackRequest],
target: Callable[[], None] = _parse_file_entrypoint,
diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py
index 9f4f9c617c5f1..62d848cf66976 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -116,7 +116,8 @@ def _execute_work(log: logging.Logger, workload: workloads.ExecuteTask) -> None:
supervise(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
ti=workload.ti, # type: ignore[arg-type]
- dag_path=workload.dag_path,
+ dag_rel_path=workload.dag_rel_path,
+ bundle_info=workload.bundle_info,
token=workload.token,
server=conf.get("workers", "execution_api_server_url", fallback="http://localhost:9091/execution/"),
log_path=workload.log_path,
diff --git a/airflow/executors/workloads.py b/airflow/executors/workloads.py
index 13331f9b5793a..9a5e425ef887d 100644
--- a/airflow/executors/workloads.py
+++ b/airflow/executors/workloads.py
@@ -39,6 +39,13 @@ class BaseActivity(BaseModel):
"""The identity token for this workload"""
+class BundleInfo(BaseModel):
+ """Schema for telling task which bundle to run with."""
+
+ name: str
+ version: str | None = None
+
+
class TaskInstance(BaseModel):
"""Schema for TaskInstance with minimal required fields needed for Executors and Task SDK."""
@@ -73,30 +80,30 @@ class ExecuteTask(BaseActivity):
ti: TaskInstance
"""The TaskInstance to execute"""
- dag_path: os.PathLike[str]
- """The filepath where the DAG can be found (likely prefixed with `DAG_FOLDER/`)"""
+ dag_rel_path: os.PathLike[str]
+ """The path to the dag file, relative to the root of the configured dag bundle"""
+ bundle_info: BundleInfo
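+ """The bundle (name and version) the task should be run with"""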
+
log_path: str | None
"""The rendered relative log filename template the task logs should be written to"""
kind: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask")
@classmethod
- def make(cls, ti: TIModel, dag_path: Path | None = None) -> ExecuteTask:
+ def make(cls, ti: TIModel, dag_rel_path: Path | None = None) -> ExecuteTask:
from pathlib import Path
from airflow.utils.helpers import log_filename_template_renderer
ser_ti = TaskInstance.model_validate(ti, from_attributes=True)
-
- dag_path = dag_path or Path(ti.dag_run.dag_model.relative_fileloc)
-
- if dag_path and not dag_path.is_absolute():
- # TODO: What about multiple dag sub folders
- dag_path = "DAGS_FOLDER" / dag_path
-
+ bundle_info = BundleInfo.model_construct(
+ name=ti.dag_model.bundle_name,
+ version=ti.dag_run.bundle_version,
+ )
+ path = dag_rel_path or Path(ti.dag_run.dag_model.relative_fileloc)
fname = log_filename_template_renderer()(ti=ti)
- return cls(ti=ser_ti, dag_path=dag_path, token="", log_path=fname)
+ return cls(ti=ser_ti, dag_rel_path=path, token="", log_path=fname, bundle_info=bundle_info)
All = Union[ExecuteTask]
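
For illustration only (not part of the change): with the fields above, an executor-side caller now builds a workload that carries both the bundle reference and the bundle-relative dag path. The `ti` below is assumed to be an already-validated workloads.TaskInstance; the bundle name and version are made-up values.

    from airflow.executors.workloads import BundleInfo, ExecuteTask

    workload = ExecuteTask(
        token="",
        ti=ti,  # assumed: an existing workloads.TaskInstance for the task to run
        dag_rel_path="example_dag.py",  # path relative to the bundle root
        log_path=None,
        bundle_info=BundleInfo(name="my-bundle", version="abc123"),  # illustrative values
    )

ExecuteTask.make(ti) builds the same payload from an ORM TaskInstance, reading the bundle name from ti.dag_model and the version snapshot from ti.dag_run.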
diff --git a/airflow/migrations/versions/0050_3_0_0_add_dagbundlemodel.py b/airflow/migrations/versions/0050_3_0_0_add_dagbundlemodel.py
index 7e6cb756c5cb2..5322be0abe763 100644
--- a/airflow/migrations/versions/0050_3_0_0_add_dagbundlemodel.py
+++ b/airflow/migrations/versions/0050_3_0_0_add_dagbundlemodel.py
@@ -53,6 +53,8 @@ def upgrade():
batch_op.create_foreign_key(
batch_op.f("dag_bundle_name_fkey"), "dag_bundle", ["bundle_name"], ["name"]
)
+ with op.batch_alter_table("dag_run", schema=None) as batch_op:
+ batch_op.add_column(sa.Column("bundle_version", sa.String(length=250), nullable=True))
def downgrade():
@@ -60,5 +62,7 @@ def downgrade():
batch_op.drop_constraint(batch_op.f("dag_bundle_name_fkey"), type_="foreignkey")
batch_op.drop_column("latest_bundle_version")
batch_op.drop_column("bundle_name")
+ with op.batch_alter_table("dag_run", schema=None) as batch_op:
+ batch_op.drop_column("bundle_version")
op.drop_table("dag_bundle")
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 82b4ca70819b9..ffd3d2dba56cc 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -262,6 +262,13 @@ def _create_orm_dagrun(
session,
triggered_by,
):
+ bundle_version = session.scalar(
+ select(
+ DagModel.latest_bundle_version,
+ ).where(
+ DagModel.dag_id == dag.dag_id,
+ )
+ )
run = DagRun(
dag_id=dag_id,
run_id=run_id,
@@ -276,6 +283,7 @@ def _create_orm_dagrun(
data_interval=data_interval,
triggered_by=triggered_by,
backfill_id=backfill_id,
+ bundle_version=bundle_version,
)
# Load defaults into the following two fields to ensure result can be serialized detached
run.log_template_id = int(session.scalar(select(func.max(LogTemplate.__table__.c.id))))
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 15d275da1a465..e7aad18672d5c 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -166,6 +166,7 @@ class DagRun(Base, LoggingMixin):
"""
dag_version_id = Column(UUIDType(binary=False), ForeignKey("dag_version.id", ondelete="CASCADE"))
dag_version = relationship("DagVersion", back_populates="dag_runs")
+ bundle_version = Column(StringID())
# Remove this `if` after upgrading Sphinx-AutoAPI
if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
@@ -238,6 +239,7 @@ def __init__(
triggered_by: DagRunTriggeredByType | None = None,
backfill_id: int | None = None,
dag_version: DagVersion | None = None,
+ bundle_version: str | None = None,
):
if data_interval is None:
# Legacy: Only happen for runs created prior to Airflow 2.2.
@@ -245,6 +247,7 @@ def __init__(
else:
self.data_interval_start, self.data_interval_end = data_interval
+ self.bundle_version = bundle_version
self.dag_id = dag_id
self.run_id = run_id
self.logical_date = logical_date
diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256
index 3616222fd880c..86afe05d37001 100644
--- a/docs/apache-airflow/img/airflow_erd.sha256
+++ b/docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
-ca59d711e6304f8bfdb25f49339d455602430dd6b880e420869fc892faef0596
\ No newline at end of file
+00d5d138d0773a6b700ada4650f5c60cc3972afefd3945ea434dea50abfda834
\ No newline at end of file
diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg
index 24f75b3247093..9fe57986ed3b3 100644
--- a/docs/apache-airflow/img/airflow_erd.svg
+++ b/docs/apache-airflow/img/airflow_erd.svg
[ERD SVG regenerated. The only substantive change is the new bundle_version [VARCHAR(250)] column on the dag_run entity; the remaining ~1000 changed lines are re-rendered layout for existing tables and relationship edges.]
diff --git a/providers/src/airflow/providers/edge/cli/edge_command.py b/providers/src/airflow/providers/edge/cli/edge_command.py
index 82fedc88e0714..d0dbd66b241c4 100644
--- a/providers/src/airflow/providers/edge/cli/edge_command.py
+++ b/providers/src/airflow/providers/edge/cli/edge_command.py
@@ -223,7 +223,8 @@ def _run_job_via_supervisor(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
# Same like in airflow/executors/local_executor.py:_execute_work()
ti=workload.ti, # type: ignore[arg-type]
- dag_path=workload.dag_path,
+ dag_rel_path=workload.dag_rel_path,
+ bundle_info=workload.bundle_info,
token=workload.token,
server=conf.get(
"workers", "execution_api_server_url", fallback="http://localhost:9091/execution/"
diff --git a/providers/tests/edge/cli/test_edge_command.py b/providers/tests/edge/cli/test_edge_command.py
index e4161901dbca6..b1b719444baf0 100644
--- a/providers/tests/edge/cli/test_edge_command.py
+++ b/providers/tests/edge/cli/test_edge_command.py
@@ -52,8 +52,9 @@
"queue": "default",
"priority_weight": 1,
},
- "dag_path": "dummy.py",
+ "dag_rel_path": "dummy.py",
"log_path": "dummy.log",
+ "bundle_info": {"name": "hello", "version": "abc"},
}
if AIRFLOW_V_3_0_PLUS
else ["test", "command"] # Airflow 2.10
diff --git a/providers/tests/edge/executors/test_edge_executor.py b/providers/tests/edge/executors/test_edge_executor.py
index 2d22744b5f1a5..3a5e6b18d69a3 100644
--- a/providers/tests/edge/executors/test_edge_executor.py
+++ b/providers/tests/edge/executors/test_edge_executor.py
@@ -301,8 +301,9 @@ def test_queue_workload(self):
queue="default",
priority_weight=1,
),
- dag_path="dummy.py",
+ dag_rel_path="dummy.py",
log_path="dummy.log",
+ bundle_info={"name": "n/a", "version": "no matter"},
)
executor.queue_workload(workload=workload)
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index 785037b6bfbaa..a8b478d07f029 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -168,6 +168,11 @@ class XComResponse(BaseModel):
value: Annotated[Any, Field(title="Value")]
+class BundleInfo(BaseModel):
+ name: str
+ version: str | None = None
+
+
class TaskInstance(BaseModel):
"""
Schema for TaskInstance model with minimal required fields needed for Runtime.
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py
index e1007d1f2fc05..b6874d47f090c 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -50,6 +50,7 @@
from pydantic import BaseModel, ConfigDict, Field, JsonValue
from airflow.sdk.api.datamodels._generated import (
+ BundleInfo,
ConnectionResponse,
TaskInstance,
TerminalTIState,
@@ -66,7 +67,8 @@ class StartupDetails(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
ti: TaskInstance
- file: str
+ dag_rel_path: str
+ bundle_info: BundleInfo
requests_fd: int
"""
The channel for the task to send requests over.
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 002b156cee2db..32895d36524d8 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -80,6 +80,7 @@
if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger, WrappedLogger
+ from airflow.executors.workloads import BundleInfo
from airflow.typing_compat import Self
@@ -316,6 +317,7 @@ class WatchedSubprocess:
@classmethod
def start(
cls,
+ *,
target: Callable[[], None] = _subprocess_main,
logger: FilteringBoundLogger | None = None,
**constructor_kwargs,
@@ -574,8 +576,10 @@ class ActivitySubprocess(WatchedSubprocess):
@classmethod
def start( # type: ignore[override]
cls,
- path: str | os.PathLike[str],
+ *,
what: TaskInstance,
+ dag_rel_path: str | os.PathLike[str],
+ bundle_info,
client: Client,
target: Callable[[], None] = _subprocess_main,
logger: FilteringBoundLogger | None = None,
@@ -584,10 +588,10 @@ def start( # type: ignore[override]
"""Fork and start a new subprocess to execute the given task."""
proc: Self = super().start(id=what.id, client=client, target=target, logger=logger, **kwargs)
# Tell the task process what it needs to do!
- proc._on_child_started(what, path)
+ proc._on_child_started(ti=what, dag_rel_path=dag_rel_path, bundle_info=bundle_info)
return proc
- def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str]):
+ def _on_child_started(self, ti: TaskInstance, dag_rel_path: str | os.PathLike[str], bundle_info):
"""Send startup message to the subprocess."""
try:
# We've forked, but the task won't start doing anything until we send it the StartupDetails
@@ -602,7 +606,8 @@ def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str]):
msg = StartupDetails.model_construct(
ti=ti,
- file=os.fspath(path),
+ dag_rel_path=os.fspath(dag_rel_path),
+ bundle_info=bundle_info,
requests_fd=self._requests_fd,
ti_context=ti_context,
)
@@ -879,7 +884,8 @@ def forward_to_log(target_log: FilteringBoundLogger, level: int) -> Generator[No
def supervise(
*,
ti: TaskInstance,
- dag_path: str | os.PathLike[str],
+ bundle_info: BundleInfo,
+ dag_rel_path: str | os.PathLike[str],
token: str,
server: str | None = None,
dry_run: bool = False,
@@ -902,14 +908,9 @@ def supervise(
if not client and ((not server) ^ dry_run):
raise ValueError(f"Can only specify one of {server=} or {dry_run=}")
- if not dag_path:
- raise ValueError("dag_path is required")
+ if not dag_rel_path:
+ raise ValueError("dag_rel_path is required")
- if (str_path := os.fspath(dag_path)).startswith("DAGS_FOLDER/"):
- from airflow.settings import DAGS_FOLDER
-
- dag_path = str_path.replace("DAGS_FOLDER/", DAGS_FOLDER + "/", 1)
-
if not client:
limits = httpx.Limits(max_keepalive_connections=1, max_connections=10)
client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token)
@@ -932,7 +933,13 @@ def supervise(
processors = logging_processors(enable_pretty_log=pretty_logs)[0]
logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind()
- process = ActivitySubprocess.start(dag_path, ti, client=client, logger=logger)
+ process = ActivitySubprocess.start(
+ dag_rel_path=dag_rel_path,
+ what=ti,
+ client=client,
+ logger=logger,
+ bundle_info=bundle_info,
+ )
exit_code = process.wait()
end = time.monotonic()
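
As an illustrative sketch of the new calling convention (values are made up; the named bundle must already be configured, e.g. via the AIRFLOW__DAG_BUNDLES__BACKENDS setting used in the tests below):

    from uuid6 import uuid7

    from airflow.executors.workloads import BundleInfo
    from airflow.sdk.api.datamodels._generated import TaskInstance
    from airflow.sdk.execution_time.supervisor import supervise

    ti = TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1)
    exit_code = supervise(
        ti=ti,
        dag_rel_path="super_basic_run.py",  # now relative to the bundle root, not an absolute path
        bundle_info=BundleInfo(name="my-bundle", version=None),
        token="",
        server="http://localhost:9091/execution/",  # execution API endpoint
    )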
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 2beff84e8f8d7..796c9f57d104a 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -24,12 +24,14 @@
from collections.abc import Iterable, Mapping
from datetime import datetime, timezone
from io import FileIO
+from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar
import attrs
import structlog
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
+from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState, TIRunContext
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import (
@@ -301,8 +303,15 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance:
from airflow.models.dagbag import DagBag
+ bundle_info = what.bundle_info
+ bundle_instance = DagBundlesManager().get_bundle(
+ name=bundle_info.name,
+ version=bundle_info.version,
+ )
+
+ dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path))
bag = DagBag(
- dag_folder=what.file,
+ dag_folder=dag_absolute_path,
include_examples=False,
safe_mode=False,
load_op_links=False,
@@ -399,7 +408,7 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]:
log = structlog.get_logger(logger_name="task")
# TODO: set the "magic loop" context vars for parsing
ti = parse(msg)
- log.debug("DAG file parsed", file=msg.file)
+ log.debug("DAG file parsed", file=msg.dag_rel_path)
else:
raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")
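
A worked example of the resolution step parse() now performs (the bundle name is illustrative and must match a configured bundle):

    import os
    from pathlib import Path

    from airflow.dag_processing.bundles.manager import DagBundlesManager
    from airflow.sdk.execution_time.comms import BundleInfo

    bundle_info = BundleInfo(name="my-bundle", version=None)
    bundle = DagBundlesManager().get_bundle(name=bundle_info.name, version=bundle_info.version)
    # The dag file is looked up inside the bundle's local checkout and handed to DagBag(dag_folder=...)
    dag_absolute_path = os.fspath(Path(bundle.path, "example_dag.py"))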
diff --git a/task_sdk/tests/execution_time/conftest.py b/task_sdk/tests/execution_time/conftest.py
index 507d17aa3508e..641f14817d899 100644
--- a/task_sdk/tests/execution_time/conftest.py
+++ b/task_sdk/tests/execution_time/conftest.py
@@ -117,7 +117,7 @@ def execute(self, context):
from uuid6 import uuid7
from airflow.sdk.api.datamodels._generated import TaskInstance
- from airflow.sdk.execution_time.comms import StartupDetails
+ from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails
def _create_task_instance(
task: BaseOperator,
@@ -148,7 +148,8 @@ def _create_task_instance(
ti=TaskInstance(
id=ti_id, task_id=task.task_id, dag_id=dag_id, run_id=run_id, try_number=try_number
),
- file="",
+ dag_rel_path="",
+ bundle_info=BundleInfo.model_construct(name="anything", version="any"),
requests_fd=0,
ti_context=ti_context,
)
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index c5879971a63d4..12c3455ccfe1e 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import inspect
+import json
import logging
import os
import selectors
@@ -27,7 +28,7 @@
from operator import attrgetter
from time import sleep
from typing import TYPE_CHECKING
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
import httpx
import psutil
@@ -35,6 +36,7 @@
from pytest_unordered import unordered
from uuid6 import uuid7
+from airflow.executors.workloads import BundleInfo
from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
@@ -57,6 +59,7 @@
from airflow.utils import timezone, timezone as tz
from task_sdk.tests.api.test_client import make_client
+from task_sdk.tests.execution_time.test_task_runner import FAKE_BUNDLE
if TYPE_CHECKING:
import kgb
@@ -69,6 +72,20 @@ def lineno():
return inspect.currentframe().f_back.f_lineno
+def local_dag_bundle_cfg(path, name="my-bundle"):
+ return {
+ "AIRFLOW__DAG_BUNDLES__BACKENDS": json.dumps(
+ [
+ {
+ "name": name,
+ "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle",
+ "kwargs": {"local_folder": str(path), "refresh_interval": 1},
+ }
+ ]
+ )
+ }
+
+
@pytest.mark.usefixtures("disable_capturing")
class TestWatchedSubprocess:
def test_reading_from_pipes(self, captured_logs, time_machine):
@@ -100,7 +117,8 @@ def subprocess_main():
time_machine.move_to(instant, tick=False)
proc = ActivitySubprocess.start(
- path=os.devnull,
+ dag_rel_path=os.devnull,
+ bundle_info=FAKE_BUNDLE,
what=TaskInstance(
id="4d828a62-a417-4936-a7a6-2b3fabacecab",
task_id="b",
@@ -167,7 +185,8 @@ def subprocess_main():
os.kill(os.getpid(), signal.SIGKILL)
proc = ActivitySubprocess.start(
- path=os.devnull,
+ dag_rel_path=os.devnull,
+ bundle_info=FAKE_BUNDLE,
what=TaskInstance(
id="4d828a62-a417-4936-a7a6-2b3fabacecab",
task_id="b",
@@ -190,7 +209,8 @@ def subprocess_main():
raise RuntimeError("Fake syntax error")
proc = ActivitySubprocess.start(
- path=os.devnull,
+ dag_rel_path=os.devnull,
+ bundle_info=FAKE_BUNDLE,
what=TaskInstance(
id=uuid7(),
task_id="b",
@@ -226,7 +246,8 @@ def subprocess_main():
ti_id = uuid7()
spy = spy_agency.spy_on(sdk_client.TaskInstanceOperations.heartbeat)
proc = ActivitySubprocess.start(
- path=os.devnull,
+ dag_rel_path=os.devnull,
+ bundle_info=FAKE_BUNDLE,
what=TaskInstance(
id=ti_id,
task_id="b",
@@ -248,7 +269,7 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine):
instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901)
time_machine.move_to(instant, tick=False)
- dagfile_path = test_dags_dir / "super_basic_run.py"
+ dagfile_path = test_dags_dir
ti = TaskInstance(
id=uuid7(),
task_id="hello",
@@ -256,8 +277,17 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine):
run_id="c",
try_number=1,
)
- # Assert Exit Code is 0
- assert supervise(ti=ti, dag_path=dagfile_path, token="", server="", dry_run=True) == 0, captured_logs
+ bundle_info = BundleInfo.model_construct(name="my-bundle", version=None)
+ with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)):
+ exit_code = supervise(
+ ti=ti,
+ dag_rel_path=dagfile_path,
+ token="",
+ server="",
+ dry_run=True,
+ bundle_info=bundle_info,
+ )
+ assert exit_code == 0, captured_logs
# We should have a log from the task!
assert {
@@ -281,7 +311,6 @@ def test_supervise_handles_deferred_task(
ti = TaskInstance(
id=uuid7(), task_id="async", dag_id="super_basic_deferred_run", run_id="d", try_number=1
)
- dagfile_path = test_dags_dir / "super_basic_deferred_run.py"
# Create a mock client to assert calls to the client
# We assume the implementation of the client is correct and only need to check the calls
@@ -291,8 +320,16 @@ def test_supervise_handles_deferred_task(
instant = tz.datetime(2024, 11, 7, 12, 34, 56, 0)
time_machine.move_to(instant, tick=False)
- # Assert supervisor runs the task successfully
- assert supervise(ti=ti, dag_path=dagfile_path, token="", client=mock_client) == 0, captured_logs
+ bundle_info = BundleInfo.model_construct(name="my-bundle", version=None)
+ with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)):
+ exit_code = supervise(
+ ti=ti,
+ dag_rel_path="super_basic_deferred_run.py",
+ token="",
+ client=mock_client,
+ bundle_info=bundle_info,
+ )
+ assert exit_code == 0, captured_logs
# Validate calls to the client
mock_client.task_instances.start.assert_called_once_with(ti.id, mocker.ANY, mocker.ANY)
@@ -341,7 +378,7 @@ def handle_request(request: httpx.Request) -> httpx.Response:
client = make_client(transport=httpx.MockTransport(handle_request))
with pytest.raises(ServerResponseError, match="Server returned error") as err:
- ActivitySubprocess.start(path=os.devnull, what=ti, client=client)
+ ActivitySubprocess.start(dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, what=ti, client=client)
assert err.value.response.status_code == 409
assert err.value.detail == {
@@ -395,10 +432,11 @@ def handle_request(request: httpx.Request) -> httpx.Response:
return httpx.Response(status_code=204)
proc = ActivitySubprocess.start(
- path=os.devnull,
+ dag_rel_path=os.devnull,
what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1),
client=make_client(transport=httpx.MockTransport(handle_request)),
target=subprocess_main,
+ bundle_info=FAKE_BUNDLE,
)
# Wait for the subprocess to finish -- it should have been terminated
@@ -666,7 +704,8 @@ def _handler(sig, frame):
ti_id = uuid7()
proc = ActivitySubprocess.start(
- path=os.devnull,
+ dag_rel_path=os.devnull,
+ bundle_info=FAKE_BUNDLE,
what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1),
client=MagicMock(spec=sdk_client.Client),
target=subprocess_main,
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index ff0cbff631772..e062c0ef33d81 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -17,11 +17,14 @@
from __future__ import annotations
+import json
+import os
import uuid
from datetime import timedelta
from pathlib import Path
from socket import socketpair
from unittest import mock
+from unittest.mock import patch
import pytest
from uuid6 import uuid7
@@ -37,6 +40,7 @@
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.execution_time.comms import (
+ BundleInfo,
ConnectionResult,
DeferTask,
GetConnection,
@@ -59,6 +63,8 @@
)
from airflow.utils import timezone
+FAKE_BUNDLE = BundleInfo.model_construct(name="anything", version="any")
+
def get_inline_dag(dag_id: str, task: BaseOperator) -> DAG:
"""Creates an inline dag and returns it based on dag_id and task."""
@@ -89,7 +95,8 @@ def test_recv_StartupDetails(self):
b'"ti_context":{"dag_run":{"dag_id":"c","run_id":"b","logical_date":"2024-12-01T01:00:00Z",'
b'"data_interval_start":"2024-12-01T00:00:00Z","data_interval_end":"2024-12-01T01:00:00Z",'
b'"start_date":"2024-12-01T01:00:00Z","end_date":null,"run_type":"manual","conf":null},'
- b'"max_tries":0,"variables":null,"connections":null},"file": "/dev/null", "requests_fd": '
+ b'"max_tries":0,"variables":null,"connections":null},"file": "/dev/null", "dag_rel_path": "/dev/null", "bundle_info": {"name": '
+ b'"any-name", "version": "any-version"}, "requests_fd": '
+ str(w2.fileno()).encode("ascii")
+ b"}\n"
)
@@ -101,7 +108,8 @@ def test_recv_StartupDetails(self):
assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab")
assert msg.ti.task_id == "a"
assert msg.ti.dag_id == "c"
- assert msg.file == "/dev/null"
+ assert msg.dag_rel_path == "/dev/null"
+ assert msg.bundle_info == BundleInfo.model_construct(name="any-name", version="any-version")
# Since this was a StartupDetails message, the decoder should open the other socket
assert decoder.request_socket is not None
@@ -113,12 +121,27 @@ def test_parse(test_dags_dir: Path, make_ti_context):
"""Test that checks parsing of a basic dag with an un-mocked parse."""
what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="a", dag_id="super_basic", run_id="c", try_number=1),
- file=str(test_dags_dir / "super_basic.py"),
+ dag_rel_path="super_basic.py",
+ bundle_info=BundleInfo.model_construct(name="my-bundle", version=None),
requests_fd=0,
ti_context=make_ti_context(),
)
- ti = parse(what)
+ with patch.dict(
+ os.environ,
+ {
+ "AIRFLOW__DAG_BUNDLES__BACKENDS": json.dumps(
+ [
+ {
+ "name": "my-bundle",
+ "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle",
+ "kwargs": {"local_folder": str(test_dags_dir), "refresh_interval": 1},
+ }
+ ]
+ ),
+ },
+ ):
+ ti = parse(what)
assert ti.task
assert ti.task.dag
@@ -325,7 +348,8 @@ def test_startup_basic_templated_dag(mocked_parse, make_ti_context, mock_supervi
ti=TaskInstance(
id=uuid7(), task_id="templated_task", dag_id="basic_templated_dag", run_id="c", try_number=1
),
- file="",
+ bundle_info=FAKE_BUNDLE,
+ dag_rel_path="",
requests_fd=0,
ti_context=make_ti_context(),
)
@@ -393,7 +417,8 @@ def execute(self, context):
what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="templated_task", dag_id="basic_dag", run_id="c", try_number=1),
- file="",
+ dag_rel_path="",
+ bundle_info=FAKE_BUNDLE,
requests_fd=0,
ti_context=make_ti_context(),
)
diff --git a/tests/api_fastapi/common/test_exceptions.py b/tests/api_fastapi/common/test_exceptions.py
index e296f2320430c..6751aff20c725 100644
--- a/tests/api_fastapi/common/test_exceptions.py
+++ b/tests/api_fastapi/common/test_exceptions.py
@@ -186,7 +186,7 @@ def test_handle_single_column_unique_constraint_error(self, session, table, expe
status_code=status.HTTP_409_CONFLICT,
detail={
"reason": "Unique constraint violation",
- "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, (SELECT max(log_template.id) AS max_1 \nFROM log_template), ?, ?, ?, ?)",
+ "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id, bundle_version) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, (SELECT max(log_template.id) AS max_1 \nFROM log_template), ?, ?, ?, ?, ?)",
"orig_error": "UNIQUE constraint failed: dag_run.dag_id, dag_run.run_id",
},
),
@@ -194,7 +194,7 @@ def test_handle_single_column_unique_constraint_error(self, session, table, expe
status_code=status.HTTP_409_CONFLICT,
detail={
"reason": "Unique constraint violation",
- "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %s, %s, %s, %s)",
+ "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id, bundle_version) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %s, %s, %s, %s, %s)",
"orig_error": "(1062, \"Duplicate entry 'test_dag_id-test_run_id' for key 'dag_run.dag_run_dag_id_run_id_key'\")",
},
),
@@ -202,7 +202,7 @@ def test_handle_single_column_unique_constraint_error(self, session, table, expe
status_code=status.HTTP_409_CONFLICT,
detail={
"reason": "Unique constraint violation",
- "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id) VALUES (%(dag_id)s, %(queued_at)s, %(logical_date)s, %(start_date)s, %(end_date)s, %(state)s, %(run_id)s, %(creating_job_id)s, %(external_trigger)s, %(run_type)s, %(triggered_by)s, %(conf)s, %(data_interval_start)s, %(data_interval_end)s, %(last_scheduling_decision)s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %(updated_at)s, %(clear_number)s, %(backfill_id)s, %(dag_version_id)s) RETURNING dag_run.id",
+ "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id, bundle_version) VALUES (%(dag_id)s, %(queued_at)s, %(logical_date)s, %(start_date)s, %(end_date)s, %(state)s, %(run_id)s, %(creating_job_id)s, %(external_trigger)s, %(run_type)s, %(triggered_by)s, %(conf)s, %(data_interval_start)s, %(data_interval_end)s, %(last_scheduling_decision)s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %(updated_at)s, %(clear_number)s, %(backfill_id)s, %(dag_version_id)s, %(bundle_version)s) RETURNING dag_run.id",
"orig_error": 'duplicate key value violates unique constraint "dag_run_dag_id_run_id_key"\nDETAIL: Key (dag_id, run_id)=(test_dag_id, test_run_id) already exists.\n',
},
),
diff --git a/tests/cli/commands/remote_commands/test_task_command.py b/tests/cli/commands/remote_commands/test_task_command.py
index 9a4c606caa469..c1e6b6b23d7cd 100644
--- a/tests/cli/commands/remote_commands/test_task_command.py
+++ b/tests/cli/commands/remote_commands/test_task_command.py
@@ -491,56 +491,56 @@ def test_cli_run_no_local_no_raw_runs_executor(self, dag_maker):
from airflow.cli.commands.remote_commands import task_command
with dag_maker(dag_id="test_executor", schedule="@daily") as dag:
- with (
- mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock,
- mock.patch(
- "airflow.executors.executor_loader.ExecutorLoader.get_default_executor"
- ) as get_default_mock,
- mock.patch("airflow.executors.local_executor.SimpleQueue"), # Prevent a task being queued
- mock.patch("airflow.executors.local_executor.LocalExecutor.end"),
- ):
- EmptyOperator(task_id="task1")
- EmptyOperator(task_id="task2", executor="foo_executor_alias")
-
- dag_maker.create_dagrun()
-
- # Reload module to consume newly mocked executor loader
- reload(task_command)
-
- loader_mock.return_value = LocalExecutor()
- get_default_mock.return_value = LocalExecutor()
-
- # In the task1 case we will use the default executor
- task_command.task_run(
- self.parser.parse_args(
- [
- "tasks",
- "run",
- "test_executor",
- "task1",
- DEFAULT_DATE.isoformat(),
- ]
- ),
- dag,
- )
- get_default_mock.assert_called_once()
- loader_mock.assert_not_called()
-
- # In the task2 case we will use the executor configured on the task
- task_command.task_run(
- self.parser.parse_args(
- [
- "tasks",
- "run",
- "test_executor",
- "task2",
- DEFAULT_DATE.isoformat(),
- ]
- ),
- dag,
- )
- get_default_mock.assert_called_once() # Call from previous task
- loader_mock.assert_called_once_with("foo_executor_alias")
+ EmptyOperator(task_id="task1")
+ EmptyOperator(task_id="task2", executor="foo_executor_alias")
+
+ dag_maker.create_dagrun()
+
+ with (
+ mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock,
+ mock.patch(
+ "airflow.executors.executor_loader.ExecutorLoader.get_default_executor"
+ ) as get_default_mock,
+ mock.patch("airflow.executors.local_executor.SimpleQueue"), # Prevent a task being queued
+ mock.patch("airflow.executors.local_executor.LocalExecutor.end"),
+ ):
+ # Reload module to consume newly mocked executor loader
+ reload(task_command)
+
+ loader_mock.return_value = LocalExecutor()
+ get_default_mock.return_value = LocalExecutor()
+
+ # In the task1 case we will use the default executor
+ task_command.task_run(
+ self.parser.parse_args(
+ [
+ "tasks",
+ "run",
+ "test_executor",
+ "task1",
+ DEFAULT_DATE.isoformat(),
+ ]
+ ),
+ dag,
+ )
+ get_default_mock.assert_called_once()
+ loader_mock.assert_not_called()
+
+ # In the task2 case we will use the executor configured on the task
+ task_command.task_run(
+ self.parser.parse_args(
+ [
+ "tasks",
+ "run",
+ "test_executor",
+ "task2",
+ DEFAULT_DATE.isoformat(),
+ ]
+ ),
+ dag,
+ )
+ get_default_mock.assert_called_once() # Call from previous task
+ loader_mock.assert_called_once_with("foo_executor_alias")
# Reload module to remove mocked version of executor loader
reload(task_command)
diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py
index 59525a8829d39..673457509ca05 100644
--- a/tests/executors/test_local_executor.py
+++ b/tests/executors/test_local_executor.py
@@ -84,11 +84,23 @@ def fake_supervise(ti, **kwargs):
with spy_on(executor._spawn_worker) as spawn_worker:
for ti in success_tis:
executor.queue_workload(
- workloads.ExecuteTask(token="", ti=ti, dag_path="some/path", log_path=None)
+ workloads.ExecuteTask(
+ token="",
+ ti=ti,
+ dag_rel_path="some/path",
+ log_path=None,
+ bundle_info=dict(name="hi", version="hi"),
+ )
)
executor.queue_workload(
- workloads.ExecuteTask(token="", ti=fail_ti, dag_path="some/path", log_path=None)
+ workloads.ExecuteTask(
+ token="",
+ ti=fail_ti,
+ dag_rel_path="some/path",
+ log_path=None,
+ bundle_info=dict(name="hi", version="hi"),
+ )
)
executor.end()
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index a63b05b0cd343..95fa85848dcc9 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -2497,7 +2497,7 @@ def test_count_number_queries(self, tasks_count):
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
for i in range(tasks_count):
EmptyOperator(task_id=f"dummy_task_{i}", owner="test", dag=dag)
- with assert_queries_count(3):
+ with assert_queries_count(4):
dag.create_dagrun(
run_id="test_dagrun_query_count",
state=State.RUNNING,