From 6d048c43f0753d96976f3c9e72262cfe3b27d052 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Tue, 14 Jan 2025 16:19:26 -0800 Subject: [PATCH] Run the task with the configured dag bundle (#44752) Ensures that dag runs are created with a reference to the bundle that was in effect at the time. And when a dag run has bundle info, the task will be run with that dag bundle version. --- .../commands/remote_commands/task_command.py | 5 +- airflow/dag_processing/processor.py | 1 + airflow/executors/local_executor.py | 3 +- airflow/executors/workloads.py | 27 +- .../versions/0050_3_0_0_add_dagbundlemodel.py | 4 + airflow/models/dag.py | 8 + airflow/models/dagrun.py | 3 + docs/apache-airflow/img/airflow_erd.sha256 | 2 +- docs/apache-airflow/img/airflow_erd.svg | 1020 +++++++++-------- .../providers/edge/cli/edge_command.py | 3 +- providers/tests/edge/cli/test_edge_command.py | 3 +- .../edge/executors/test_edge_executor.py | 3 +- .../airflow/sdk/api/datamodels/_generated.py | 5 + .../src/airflow/sdk/execution_time/comms.py | 4 +- .../airflow/sdk/execution_time/supervisor.py | 31 +- .../airflow/sdk/execution_time/task_runner.py | 13 +- task_sdk/tests/execution_time/conftest.py | 5 +- .../tests/execution_time/test_supervisor.py | 67 +- .../tests/execution_time/test_task_runner.py | 37 +- tests/api_fastapi/common/test_exceptions.py | 6 +- .../remote_commands/test_task_command.py | 100 +- tests/executors/test_local_executor.py | 16 +- tests/models/test_dag.py | 2 +- 23 files changed, 751 insertions(+), 617 deletions(-) diff --git a/airflow/cli/commands/remote_commands/task_command.py b/airflow/cli/commands/remote_commands/task_command.py index 6c591801515d5..2329ee25bccaa 100644 --- a/airflow/cli/commands/remote_commands/task_command.py +++ b/airflow/cli/commands/remote_commands/task_command.py @@ -245,6 +245,9 @@ def _get_ti( # we do refresh_from_task so that if TI has come back via RPC, we ensure that ti.task # is the original task object and not the result of the round trip ti.refresh_from_task(task, pool_override=pool) + + ti.dag_model # we must ensure dag model is loaded eagerly for bundle info + return ti, dr_created @@ -286,7 +289,7 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None: if executor.queue_workload.__func__ is not BaseExecutor.queue_workload: # type: ignore[attr-defined] from airflow.executors import workloads - workload = workloads.ExecuteTask.make(ti, dag_path=dag.relative_fileloc) + workload = workloads.ExecuteTask.make(ti, dag_rel_path=dag.relative_fileloc) with create_session() as session: executor.queue_workload(workload, session) else: diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index c175f3f68c726..5d73c50e7d736 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -203,6 +203,7 @@ class DagFileProcessorProcess(WatchedSubprocess): @classmethod def start( # type: ignore[override] cls, + *, path: str | os.PathLike[str], callbacks: list[CallbackRequest], target: Callable[[], None] = _parse_file_entrypoint, diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index 9f4f9c617c5f1..62d848cf66976 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -116,7 +116,8 @@ def _execute_work(log: logging.Logger, workload: workloads.ExecuteTask) -> None: supervise( # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. 
ti=workload.ti, # type: ignore[arg-type] - dag_path=workload.dag_path, + dag_rel_path=workload.dag_rel_path, + bundle_info=workload.bundle_info, token=workload.token, server=conf.get("workers", "execution_api_server_url", fallback="http://localhost:9091/execution/"), log_path=workload.log_path, diff --git a/airflow/executors/workloads.py b/airflow/executors/workloads.py index 13331f9b5793a..9a5e425ef887d 100644 --- a/airflow/executors/workloads.py +++ b/airflow/executors/workloads.py @@ -39,6 +39,13 @@ class BaseActivity(BaseModel): """The identity token for this workload""" +class BundleInfo(BaseModel): + """Schema for telling task which bundle to run with.""" + + name: str + version: str | None = None + + class TaskInstance(BaseModel): """Schema for TaskInstance with minimal required fields needed for Executors and Task SDK.""" @@ -73,30 +80,30 @@ class ExecuteTask(BaseActivity): ti: TaskInstance """The TaskInstance to execute""" - dag_path: os.PathLike[str] + dag_rel_path: os.PathLike[str] """The filepath where the DAG can be found (likely prefixed with `DAG_FOLDER/`)""" + bundle_info: BundleInfo + log_path: str | None """The rendered relative log filename template the task logs should be written to""" kind: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask") @classmethod - def make(cls, ti: TIModel, dag_path: Path | None = None) -> ExecuteTask: + def make(cls, ti: TIModel, dag_rel_path: Path | None = None) -> ExecuteTask: from pathlib import Path from airflow.utils.helpers import log_filename_template_renderer ser_ti = TaskInstance.model_validate(ti, from_attributes=True) - - dag_path = dag_path or Path(ti.dag_run.dag_model.relative_fileloc) - - if dag_path and not dag_path.is_absolute(): - # TODO: What about multiple dag sub folders - dag_path = "DAGS_FOLDER" / dag_path - + bundle_info = BundleInfo.model_construct( + name=ti.dag_model.bundle_name, + version=ti.dag_run.bundle_version, + ) + path = dag_rel_path or Path(ti.dag_run.dag_model.relative_fileloc) fname = log_filename_template_renderer()(ti=ti) - return cls(ti=ser_ti, dag_path=dag_path, token="", log_path=fname) + return cls(ti=ser_ti, dag_rel_path=path, token="", log_path=fname, bundle_info=bundle_info) All = Union[ExecuteTask] diff --git a/airflow/migrations/versions/0050_3_0_0_add_dagbundlemodel.py b/airflow/migrations/versions/0050_3_0_0_add_dagbundlemodel.py index 7e6cb756c5cb2..5322be0abe763 100644 --- a/airflow/migrations/versions/0050_3_0_0_add_dagbundlemodel.py +++ b/airflow/migrations/versions/0050_3_0_0_add_dagbundlemodel.py @@ -53,6 +53,8 @@ def upgrade(): batch_op.create_foreign_key( batch_op.f("dag_bundle_name_fkey"), "dag_bundle", ["bundle_name"], ["name"] ) + with op.batch_alter_table("dag_run", schema=None) as batch_op: + batch_op.add_column(sa.Column("bundle_version", sa.String(length=250), nullable=True)) def downgrade(): @@ -60,5 +62,7 @@ def downgrade(): batch_op.drop_constraint(batch_op.f("dag_bundle_name_fkey"), type_="foreignkey") batch_op.drop_column("latest_bundle_version") batch_op.drop_column("bundle_name") + with op.batch_alter_table("dag_run", schema=None) as batch_op: + batch_op.drop_column("bundle_version") op.drop_table("dag_bundle") diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 82b4ca70819b9..ffd3d2dba56cc 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -262,6 +262,13 @@ def _create_orm_dagrun( session, triggered_by, ): + bundle_version = session.scalar( + select( + DagModel.latest_bundle_version, + ).where( + DagModel.dag_id == dag.dag_id, + 
)
+    )
     run = DagRun(
         dag_id=dag_id,
         run_id=run_id,
@@ -276,6 +283,7 @@ def _create_orm_dagrun(
         data_interval=data_interval,
         triggered_by=triggered_by,
         backfill_id=backfill_id,
+        bundle_version=bundle_version,
     )
     # Load defaults into the following two fields to ensure result can be serialized detached
     run.log_template_id = int(session.scalar(select(func.max(LogTemplate.__table__.c.id))))
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 15d275da1a465..e7aad18672d5c 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -166,6 +166,7 @@ class DagRun(Base, LoggingMixin):
     """
     dag_version_id = Column(UUIDType(binary=False), ForeignKey("dag_version.id", ondelete="CASCADE"))
     dag_version = relationship("DagVersion", back_populates="dag_runs")
+    bundle_version = Column(StringID())

     # Remove this `if` after upgrading Sphinx-AutoAPI
     if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
@@ -238,6 +239,7 @@ def __init__(
         triggered_by: DagRunTriggeredByType | None = None,
         backfill_id: int | None = None,
         dag_version: DagVersion | None = None,
+        bundle_version: str | None = None,
     ):
         if data_interval is None:
             # Legacy: Only happen for runs created prior to Airflow 2.2.
@@ -245,6 +247,7 @@
         else:
             self.data_interval_start, self.data_interval_end = data_interval

+        self.bundle_version = bundle_version
         self.dag_id = dag_id
         self.run_id = run_id
         self.logical_date = logical_date
diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256
index 3616222fd880c..86afe05d37001 100644
--- a/docs/apache-airflow/img/airflow_erd.sha256
+++ b/docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
-ca59d711e6304f8bfdb25f49339d455602430dd6b880e420869fc892faef0596
\ No newline at end of file
+00d5d138d0773a6b700ada4650f5c60cc3972afefd3945ea434dea50abfda834
\ No newline at end of file
diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg
index 24f75b3247093..9fe57986ed3b3 100644
--- a/docs/apache-airflow/img/airflow_erd.svg
+++ b/docs/apache-airflow/img/airflow_erd.svg
[Regenerated ERD diagram: ~1020 lines of SVG layout/coordinate churn elided. The only schema-relevant change shown in the diagram is the new bundle_version [VARCHAR(250)] column on the dag_run entity.]
diff --git a/providers/src/airflow/providers/edge/cli/edge_command.py b/providers/src/airflow/providers/edge/cli/edge_command.py
index 82fedc88e0714..d0dbd66b241c4 100644
--- a/providers/src/airflow/providers/edge/cli/edge_command.py
+++ b/providers/src/airflow/providers/edge/cli/edge_command.py
@@ -223,7 +223,8 @@ def _run_job_via_supervisor(
             # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
# Same like in airflow/executors/local_executor.py:_execute_work() ti=workload.ti, # type: ignore[arg-type] - dag_path=workload.dag_path, + dag_rel_path=workload.dag_rel_path, + bundle_info=workload.bundle_info, token=workload.token, server=conf.get( "workers", "execution_api_server_url", fallback="http://localhost:9091/execution/" diff --git a/providers/tests/edge/cli/test_edge_command.py b/providers/tests/edge/cli/test_edge_command.py index e4161901dbca6..b1b719444baf0 100644 --- a/providers/tests/edge/cli/test_edge_command.py +++ b/providers/tests/edge/cli/test_edge_command.py @@ -52,8 +52,9 @@ "queue": "default", "priority_weight": 1, }, - "dag_path": "dummy.py", + "dag_rel_path": "dummy.py", "log_path": "dummy.log", + "bundle_info": {"name": "hello", "version": "abc"}, } if AIRFLOW_V_3_0_PLUS else ["test", "command"] # Airflow 2.10 diff --git a/providers/tests/edge/executors/test_edge_executor.py b/providers/tests/edge/executors/test_edge_executor.py index 2d22744b5f1a5..3a5e6b18d69a3 100644 --- a/providers/tests/edge/executors/test_edge_executor.py +++ b/providers/tests/edge/executors/test_edge_executor.py @@ -301,8 +301,9 @@ def test_queue_workload(self): queue="default", priority_weight=1, ), - dag_path="dummy.py", + dag_rel_path="dummy.py", log_path="dummy.log", + bundle_info={"name": "n/a", "version": "no matter"}, ) executor.queue_workload(workload=workload) diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 785037b6bfbaa..a8b478d07f029 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -168,6 +168,11 @@ class XComResponse(BaseModel): value: Annotated[Any, Field(title="Value")] +class BundleInfo(BaseModel): + name: str + version: str | None = None + + class TaskInstance(BaseModel): """ Schema for TaskInstance model with minimal required fields needed for Runtime. diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index e1007d1f2fc05..b6874d47f090c 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -50,6 +50,7 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue from airflow.sdk.api.datamodels._generated import ( + BundleInfo, ConnectionResponse, TaskInstance, TerminalTIState, @@ -66,7 +67,8 @@ class StartupDetails(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) ti: TaskInstance - file: str + dag_rel_path: str + bundle_info: BundleInfo requests_fd: int """ The channel for the task to send requests over. 
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 002b156cee2db..32895d36524d8 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -80,6 +80,7 @@ if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger, WrappedLogger + from airflow.executors.workloads import BundleInfo from airflow.typing_compat import Self @@ -316,6 +317,7 @@ class WatchedSubprocess: @classmethod def start( cls, + *, target: Callable[[], None] = _subprocess_main, logger: FilteringBoundLogger | None = None, **constructor_kwargs, @@ -574,8 +576,10 @@ class ActivitySubprocess(WatchedSubprocess): @classmethod def start( # type: ignore[override] cls, - path: str | os.PathLike[str], + *, what: TaskInstance, + dag_rel_path: str | os.PathLike[str], + bundle_info, client: Client, target: Callable[[], None] = _subprocess_main, logger: FilteringBoundLogger | None = None, @@ -584,10 +588,10 @@ def start( # type: ignore[override] """Fork and start a new subprocess to execute the given task.""" proc: Self = super().start(id=what.id, client=client, target=target, logger=logger, **kwargs) # Tell the task process what it needs to do! - proc._on_child_started(what, path) + proc._on_child_started(ti=what, dag_rel_path=dag_rel_path, bundle_info=bundle_info) return proc - def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str]): + def _on_child_started(self, ti: TaskInstance, dag_rel_path: str | os.PathLike[str], bundle_info): """Send startup message to the subprocess.""" try: # We've forked, but the task won't start doing anything until we send it the StartupDetails @@ -602,7 +606,8 @@ def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str]): msg = StartupDetails.model_construct( ti=ti, - file=os.fspath(path), + dag_rel_path=os.fspath(dag_rel_path), + bundle_info=bundle_info, requests_fd=self._requests_fd, ti_context=ti_context, ) @@ -879,7 +884,8 @@ def forward_to_log(target_log: FilteringBoundLogger, level: int) -> Generator[No def supervise( *, ti: TaskInstance, - dag_path: str | os.PathLike[str], + bundle_info: BundleInfo, + dag_rel_path: str | os.PathLike[str], token: str, server: str | None = None, dry_run: bool = False, @@ -902,14 +908,9 @@ def supervise( if not client and ((not server) ^ dry_run): raise ValueError(f"Can only specify one of {server=} or {dry_run=}") - if not dag_path: + if not dag_rel_path: raise ValueError("dag_path is required") - if (str_path := os.fspath(dag_path)).startswith("DAGS_FOLDER/"): - from airflow.settings import DAGS_FOLDER - - dag_path = str_path.replace("DAGS_FOLDER/", DAGS_FOLDER + "/", 1) - if not client: limits = httpx.Limits(max_keepalive_connections=1, max_connections=10) client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token) @@ -932,7 +933,13 @@ def supervise( processors = logging_processors(enable_pretty_log=pretty_logs)[0] logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind() - process = ActivitySubprocess.start(dag_path, ti, client=client, logger=logger) + process = ActivitySubprocess.start( + dag_rel_path=dag_rel_path, + what=ti, + client=client, + logger=logger, + bundle_info=bundle_info, + ) exit_code = process.wait() end = time.monotonic() diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 2beff84e8f8d7..796c9f57d104a 100644 --- 
a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -24,12 +24,14 @@ from collections.abc import Iterable, Mapping from datetime import datetime, timezone from io import FileIO +from pathlib import Path from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar import attrs import structlog from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter +from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState, TIRunContext from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.execution_time.comms import ( @@ -301,8 +303,15 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance: from airflow.models.dagbag import DagBag + bundle_info = what.bundle_info + bundle_instance = DagBundlesManager().get_bundle( + name=bundle_info.name, + version=bundle_info.version, + ) + + dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path)) bag = DagBag( - dag_folder=what.file, + dag_folder=dag_absolute_path, include_examples=False, safe_mode=False, load_op_links=False, @@ -399,7 +408,7 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]: log = structlog.get_logger(logger_name="task") # TODO: set the "magic loop" context vars for parsing ti = parse(msg) - log.debug("DAG file parsed", file=msg.file) + log.debug("DAG file parsed", file=msg.dag_rel_path) else: raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") diff --git a/task_sdk/tests/execution_time/conftest.py b/task_sdk/tests/execution_time/conftest.py index 507d17aa3508e..641f14817d899 100644 --- a/task_sdk/tests/execution_time/conftest.py +++ b/task_sdk/tests/execution_time/conftest.py @@ -117,7 +117,7 @@ def execute(self, context): from uuid6 import uuid7 from airflow.sdk.api.datamodels._generated import TaskInstance - from airflow.sdk.execution_time.comms import StartupDetails + from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails def _create_task_instance( task: BaseOperator, @@ -148,7 +148,8 @@ def _create_task_instance( ti=TaskInstance( id=ti_id, task_id=task.task_id, dag_id=dag_id, run_id=run_id, try_number=try_number ), - file="", + dag_rel_path="", + bundle_info=BundleInfo.model_construct(name="anything", version="any"), requests_fd=0, ti_context=ti_context, ) diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index c5879971a63d4..12c3455ccfe1e 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -18,6 +18,7 @@ from __future__ import annotations import inspect +import json import logging import os import selectors @@ -27,7 +28,7 @@ from operator import attrgetter from time import sleep from typing import TYPE_CHECKING -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import httpx import psutil @@ -35,6 +36,7 @@ from pytest_unordered import unordered from uuid6 import uuid7 +from airflow.executors.workloads import BundleInfo from airflow.sdk.api import client as sdk_client from airflow.sdk.api.client import ServerResponseError from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState @@ -57,6 +59,7 @@ from airflow.utils import timezone, timezone as tz from task_sdk.tests.api.test_client import make_client +from task_sdk.tests.execution_time.test_task_runner import FAKE_BUNDLE if TYPE_CHECKING: import 
kgb @@ -69,6 +72,20 @@ def lineno(): return inspect.currentframe().f_back.f_lineno +def local_dag_bundle_cfg(path, name="my-bundle"): + return { + "AIRFLOW__DAG_BUNDLES__BACKENDS": json.dumps( + [ + { + "name": name, + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": {"local_folder": str(path), "refresh_interval": 1}, + } + ] + ) + } + + @pytest.mark.usefixtures("disable_capturing") class TestWatchedSubprocess: def test_reading_from_pipes(self, captured_logs, time_machine): @@ -100,7 +117,8 @@ def subprocess_main(): time_machine.move_to(instant, tick=False) proc = ActivitySubprocess.start( - path=os.devnull, + dag_rel_path=os.devnull, + bundle_info=FAKE_BUNDLE, what=TaskInstance( id="4d828a62-a417-4936-a7a6-2b3fabacecab", task_id="b", @@ -167,7 +185,8 @@ def subprocess_main(): os.kill(os.getpid(), signal.SIGKILL) proc = ActivitySubprocess.start( - path=os.devnull, + dag_rel_path=os.devnull, + bundle_info=FAKE_BUNDLE, what=TaskInstance( id="4d828a62-a417-4936-a7a6-2b3fabacecab", task_id="b", @@ -190,7 +209,8 @@ def subprocess_main(): raise RuntimeError("Fake syntax error") proc = ActivitySubprocess.start( - path=os.devnull, + dag_rel_path=os.devnull, + bundle_info=FAKE_BUNDLE, what=TaskInstance( id=uuid7(), task_id="b", @@ -226,7 +246,8 @@ def subprocess_main(): ti_id = uuid7() spy = spy_agency.spy_on(sdk_client.TaskInstanceOperations.heartbeat) proc = ActivitySubprocess.start( - path=os.devnull, + dag_rel_path=os.devnull, + bundle_info=FAKE_BUNDLE, what=TaskInstance( id=ti_id, task_id="b", @@ -248,7 +269,7 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine): instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901) time_machine.move_to(instant, tick=False) - dagfile_path = test_dags_dir / "super_basic_run.py" + dagfile_path = test_dags_dir ti = TaskInstance( id=uuid7(), task_id="hello", @@ -256,8 +277,17 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine): run_id="c", try_number=1, ) - # Assert Exit Code is 0 - assert supervise(ti=ti, dag_path=dagfile_path, token="", server="", dry_run=True) == 0, captured_logs + bundle_info = BundleInfo.model_construct(name="my-bundle", version=None) + with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): + exit_code = supervise( + ti=ti, + dag_rel_path=dagfile_path, + token="", + server="", + dry_run=True, + bundle_info=bundle_info, + ) + assert exit_code == 0, captured_logs # We should have a log from the task! 
assert { @@ -281,7 +311,6 @@ def test_supervise_handles_deferred_task( ti = TaskInstance( id=uuid7(), task_id="async", dag_id="super_basic_deferred_run", run_id="d", try_number=1 ) - dagfile_path = test_dags_dir / "super_basic_deferred_run.py" # Create a mock client to assert calls to the client # We assume the implementation of the client is correct and only need to check the calls @@ -291,8 +320,16 @@ def test_supervise_handles_deferred_task( instant = tz.datetime(2024, 11, 7, 12, 34, 56, 0) time_machine.move_to(instant, tick=False) - # Assert supervisor runs the task successfully - assert supervise(ti=ti, dag_path=dagfile_path, token="", client=mock_client) == 0, captured_logs + bundle_info = BundleInfo.model_construct(name="my-bundle", version=None) + with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): + exit_code = supervise( + ti=ti, + dag_rel_path="super_basic_deferred_run.py", + token="", + client=mock_client, + bundle_info=bundle_info, + ) + assert exit_code == 0, captured_logs # Validate calls to the client mock_client.task_instances.start.assert_called_once_with(ti.id, mocker.ANY, mocker.ANY) @@ -341,7 +378,7 @@ def handle_request(request: httpx.Request) -> httpx.Response: client = make_client(transport=httpx.MockTransport(handle_request)) with pytest.raises(ServerResponseError, match="Server returned error") as err: - ActivitySubprocess.start(path=os.devnull, what=ti, client=client) + ActivitySubprocess.start(dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, what=ti, client=client) assert err.value.response.status_code == 409 assert err.value.detail == { @@ -395,10 +432,11 @@ def handle_request(request: httpx.Request) -> httpx.Response: return httpx.Response(status_code=204) proc = ActivitySubprocess.start( - path=os.devnull, + dag_rel_path=os.devnull, what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1), client=make_client(transport=httpx.MockTransport(handle_request)), target=subprocess_main, + bundle_info=FAKE_BUNDLE, ) # Wait for the subprocess to finish -- it should have been terminated @@ -666,7 +704,8 @@ def _handler(sig, frame): ti_id = uuid7() proc = ActivitySubprocess.start( - path=os.devnull, + dag_rel_path=os.devnull, + bundle_info=FAKE_BUNDLE, what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1), client=MagicMock(spec=sdk_client.Client), target=subprocess_main, diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index ff0cbff631772..e062c0ef33d81 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -17,11 +17,14 @@ from __future__ import annotations +import json +import os import uuid from datetime import timedelta from pathlib import Path from socket import socketpair from unittest import mock +from unittest.mock import patch import pytest from uuid6 import uuid7 @@ -37,6 +40,7 @@ from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState from airflow.sdk.definitions.variable import Variable from airflow.sdk.execution_time.comms import ( + BundleInfo, ConnectionResult, DeferTask, GetConnection, @@ -59,6 +63,8 @@ ) from airflow.utils import timezone +FAKE_BUNDLE = BundleInfo.model_construct(name="anything", version="any") + def get_inline_dag(dag_id: str, task: BaseOperator) -> DAG: """Creates an inline dag and returns it based on dag_id and task.""" @@ -89,7 +95,8 @@ def test_recv_StartupDetails(self): 
b'"ti_context":{"dag_run":{"dag_id":"c","run_id":"b","logical_date":"2024-12-01T01:00:00Z",' b'"data_interval_start":"2024-12-01T00:00:00Z","data_interval_end":"2024-12-01T01:00:00Z",' b'"start_date":"2024-12-01T01:00:00Z","end_date":null,"run_type":"manual","conf":null},' - b'"max_tries":0,"variables":null,"connections":null},"file": "/dev/null", "requests_fd": ' + b'"max_tries":0,"variables":null,"connections":null},"file": "/dev/null", "dag_rel_path": "/dev/null", "bundle_info": {"name": ' + b'"any-name", "version": "any-version"}, "requests_fd": ' + str(w2.fileno()).encode("ascii") + b"}\n" ) @@ -101,7 +108,8 @@ def test_recv_StartupDetails(self): assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab") assert msg.ti.task_id == "a" assert msg.ti.dag_id == "c" - assert msg.file == "/dev/null" + assert msg.dag_rel_path == "/dev/null" + assert msg.bundle_info == BundleInfo.model_construct(name="any-name", version="any-version") # Since this was a StartupDetails message, the decoder should open the other socket assert decoder.request_socket is not None @@ -113,12 +121,27 @@ def test_parse(test_dags_dir: Path, make_ti_context): """Test that checks parsing of a basic dag with an un-mocked parse.""" what = StartupDetails( ti=TaskInstance(id=uuid7(), task_id="a", dag_id="super_basic", run_id="c", try_number=1), - file=str(test_dags_dir / "super_basic.py"), + dag_rel_path="super_basic.py", + bundle_info=BundleInfo.model_construct(name="my-bundle", version=None), requests_fd=0, ti_context=make_ti_context(), ) - ti = parse(what) + with patch.dict( + os.environ, + { + "AIRFLOW__DAG_BUNDLES__BACKENDS": json.dumps( + [ + { + "name": "my-bundle", + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": {"local_folder": str(test_dags_dir), "refresh_interval": 1}, + } + ] + ), + }, + ): + ti = parse(what) assert ti.task assert ti.task.dag @@ -325,7 +348,8 @@ def test_startup_basic_templated_dag(mocked_parse, make_ti_context, mock_supervi ti=TaskInstance( id=uuid7(), task_id="templated_task", dag_id="basic_templated_dag", run_id="c", try_number=1 ), - file="", + bundle_info=FAKE_BUNDLE, + dag_rel_path="", requests_fd=0, ti_context=make_ti_context(), ) @@ -393,7 +417,8 @@ def execute(self, context): what = StartupDetails( ti=TaskInstance(id=uuid7(), task_id="templated_task", dag_id="basic_dag", run_id="c", try_number=1), - file="", + dag_rel_path="", + bundle_info=FAKE_BUNDLE, requests_fd=0, ti_context=make_ti_context(), ) diff --git a/tests/api_fastapi/common/test_exceptions.py b/tests/api_fastapi/common/test_exceptions.py index e296f2320430c..6751aff20c725 100644 --- a/tests/api_fastapi/common/test_exceptions.py +++ b/tests/api_fastapi/common/test_exceptions.py @@ -186,7 +186,7 @@ def test_handle_single_column_unique_constraint_error(self, session, table, expe status_code=status.HTTP_409_CONFLICT, detail={ "reason": "Unique constraint violation", - "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, (SELECT max(log_template.id) AS max_1 \nFROM log_template), ?, ?, ?, ?)", + "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, 
data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id, bundle_version) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, (SELECT max(log_template.id) AS max_1 \nFROM log_template), ?, ?, ?, ?, ?)", "orig_error": "UNIQUE constraint failed: dag_run.dag_id, dag_run.run_id", }, ), @@ -194,7 +194,7 @@ def test_handle_single_column_unique_constraint_error(self, session, table, expe status_code=status.HTTP_409_CONFLICT, detail={ "reason": "Unique constraint violation", - "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %s, %s, %s, %s)", + "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id, bundle_version) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %s, %s, %s, %s, %s)", "orig_error": "(1062, \"Duplicate entry 'test_dag_id-test_run_id' for key 'dag_run.dag_run_dag_id_run_id_key'\")", }, ), @@ -202,7 +202,7 @@ def test_handle_single_column_unique_constraint_error(self, session, table, expe status_code=status.HTTP_409_CONFLICT, detail={ "reason": "Unique constraint violation", - "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id) VALUES (%(dag_id)s, %(queued_at)s, %(logical_date)s, %(start_date)s, %(end_date)s, %(state)s, %(run_id)s, %(creating_job_id)s, %(external_trigger)s, %(run_type)s, %(triggered_by)s, %(conf)s, %(data_interval_start)s, %(data_interval_end)s, %(last_scheduling_decision)s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %(updated_at)s, %(clear_number)s, %(backfill_id)s, %(dag_version_id)s) RETURNING dag_run.id", + "statement": "INSERT INTO dag_run (dag_id, queued_at, logical_date, start_date, end_date, state, run_id, creating_job_id, external_trigger, run_type, triggered_by, conf, data_interval_start, data_interval_end, last_scheduling_decision, log_template_id, updated_at, clear_number, backfill_id, dag_version_id, bundle_version) VALUES (%(dag_id)s, %(queued_at)s, %(logical_date)s, %(start_date)s, %(end_date)s, %(state)s, %(run_id)s, %(creating_job_id)s, %(external_trigger)s, %(run_type)s, %(triggered_by)s, %(conf)s, %(data_interval_start)s, %(data_interval_end)s, %(last_scheduling_decision)s, (SELECT max(log_template.id) AS max_1 \nFROM log_template), %(updated_at)s, %(clear_number)s, %(backfill_id)s, %(dag_version_id)s, %(bundle_version)s) RETURNING dag_run.id", "orig_error": 'duplicate key value violates unique constraint "dag_run_dag_id_run_id_key"\nDETAIL: Key (dag_id, run_id)=(test_dag_id, test_run_id) already exists.\n', }, ), diff --git a/tests/cli/commands/remote_commands/test_task_command.py 
b/tests/cli/commands/remote_commands/test_task_command.py index 9a4c606caa469..c1e6b6b23d7cd 100644 --- a/tests/cli/commands/remote_commands/test_task_command.py +++ b/tests/cli/commands/remote_commands/test_task_command.py @@ -491,56 +491,56 @@ def test_cli_run_no_local_no_raw_runs_executor(self, dag_maker): from airflow.cli.commands.remote_commands import task_command with dag_maker(dag_id="test_executor", schedule="@daily") as dag: - with ( - mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock, - mock.patch( - "airflow.executors.executor_loader.ExecutorLoader.get_default_executor" - ) as get_default_mock, - mock.patch("airflow.executors.local_executor.SimpleQueue"), # Prevent a task being queued - mock.patch("airflow.executors.local_executor.LocalExecutor.end"), - ): - EmptyOperator(task_id="task1") - EmptyOperator(task_id="task2", executor="foo_executor_alias") - - dag_maker.create_dagrun() - - # Reload module to consume newly mocked executor loader - reload(task_command) - - loader_mock.return_value = LocalExecutor() - get_default_mock.return_value = LocalExecutor() - - # In the task1 case we will use the default executor - task_command.task_run( - self.parser.parse_args( - [ - "tasks", - "run", - "test_executor", - "task1", - DEFAULT_DATE.isoformat(), - ] - ), - dag, - ) - get_default_mock.assert_called_once() - loader_mock.assert_not_called() - - # In the task2 case we will use the executor configured on the task - task_command.task_run( - self.parser.parse_args( - [ - "tasks", - "run", - "test_executor", - "task2", - DEFAULT_DATE.isoformat(), - ] - ), - dag, - ) - get_default_mock.assert_called_once() # Call from previous task - loader_mock.assert_called_once_with("foo_executor_alias") + EmptyOperator(task_id="task1") + EmptyOperator(task_id="task2", executor="foo_executor_alias") + + dag_maker.create_dagrun() + + with ( + mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock, + mock.patch( + "airflow.executors.executor_loader.ExecutorLoader.get_default_executor" + ) as get_default_mock, + mock.patch("airflow.executors.local_executor.SimpleQueue"), # Prevent a task being queued + mock.patch("airflow.executors.local_executor.LocalExecutor.end"), + ): + # Reload module to consume newly mocked executor loader + reload(task_command) + + loader_mock.return_value = LocalExecutor() + get_default_mock.return_value = LocalExecutor() + + # In the task1 case we will use the default executor + task_command.task_run( + self.parser.parse_args( + [ + "tasks", + "run", + "test_executor", + "task1", + DEFAULT_DATE.isoformat(), + ] + ), + dag, + ) + get_default_mock.assert_called_once() + loader_mock.assert_not_called() + + # In the task2 case we will use the executor configured on the task + task_command.task_run( + self.parser.parse_args( + [ + "tasks", + "run", + "test_executor", + "task2", + DEFAULT_DATE.isoformat(), + ] + ), + dag, + ) + get_default_mock.assert_called_once() # Call from previous task + loader_mock.assert_called_once_with("foo_executor_alias") # Reload module to remove mocked version of executor loader reload(task_command) diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py index 59525a8829d39..673457509ca05 100644 --- a/tests/executors/test_local_executor.py +++ b/tests/executors/test_local_executor.py @@ -84,11 +84,23 @@ def fake_supervise(ti, **kwargs): with spy_on(executor._spawn_worker) as spawn_worker: for ti in success_tis: executor.queue_workload( - 
workloads.ExecuteTask(token="", ti=ti, dag_path="some/path", log_path=None) + workloads.ExecuteTask( + token="", + ti=ti, + dag_rel_path="some/path", + log_path=None, + bundle_info=dict(name="hi", version="hi"), + ) ) executor.queue_workload( - workloads.ExecuteTask(token="", ti=fail_ti, dag_path="some/path", log_path=None) + workloads.ExecuteTask( + token="", + ti=fail_ti, + dag_rel_path="some/path", + log_path=None, + bundle_info=dict(name="hi", version="hi"), + ) ) executor.end() diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index a63b05b0cd343..95fa85848dcc9 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -2497,7 +2497,7 @@ def test_count_number_queries(self, tasks_count): triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} for i in range(tasks_count): EmptyOperator(task_id=f"dummy_task_{i}", owner="test", dag=dag) - with assert_queries_count(3): + with assert_queries_count(4): dag.create_dagrun( run_id="test_dagrun_query_count", state=State.RUNNING,
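
Illustrative sketch (not part of this patch): how the bundle plumbing added above fits together, assuming an Airflow 3 environment with a LocalDagBundle named "my-bundle" configured via AIRFLOW__DAG_BUNDLES__BACKENDS, as in local_dag_bundle_cfg() in test_supervisor.py. The calls mirror workloads.ExecuteTask.make() on the executor side and task_runner.parse() on the worker side; treat it as a sketch under those assumptions, not a verified snippet.

    # Sketch only: assumes a "my-bundle" LocalDagBundle is configured
    # (see local_dag_bundle_cfg() in the test_supervisor.py changes above).
    import os
    from pathlib import Path

    from airflow.dag_processing.bundles.manager import DagBundlesManager
    from airflow.executors.workloads import BundleInfo

    # Executor side: the workload now records which bundle (and version)
    # the dag run was created against, as ExecuteTask.make() does.
    bundle_info = BundleInfo.model_construct(name="my-bundle", version=None)

    # Task-runner side: resolve the bundle and join it with the dag-relative
    # path, mirroring parse() in task_runner.py. "super_basic.py" stands in
    # for the workload's dag_rel_path.
    bundle = DagBundlesManager().get_bundle(name=bundle_info.name, version=bundle_info.version)
    dag_absolute_path = os.fspath(Path(bundle.path, "super_basic.py"))
    print(dag_absolute_path)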