Interface and implementations of the Dask Task Runner.
Task Runners in Prefect are responsible for managing the execution of Prefect task runs.
Generally speaking, users are not expected to interact with
task runners outside of configuring and initializing them for a flow.
A parallel task runner that submits tasks to the dask.distributed scheduler.
By default a temporary distributed.LocalCluster is created (and
subsequently torn down) within the start() context manager. To use a
different cluster class (e.g. dask_kubernetes.KubeCluster), you can
specify cluster_class/cluster_kwargs.
Alternatively, if you already have a dask cluster running, you can provide
the cluster object via the cluster kwarg or the address of the scheduler
via the address kwarg.
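These three configuration modes (temporary cluster, existing cluster object, scheduler address) cannot be mixed freely. The following is a minimal standalone sketch of that rule, modeled on the checks in the `DaskTaskRunner.__init__` source listed on this page. The function name `check_cluster_options` and its return strings are hypothetical and not part of `prefect_dask`, and the sketch leaves out the `asynchronous` checks that the real constructor also performs.

```python
from typing import Any, Optional


def check_cluster_options(
    cluster: Optional[Any] = None,
    address: Optional[str] = None,
    cluster_class: Optional[Any] = None,
    cluster_kwargs: Optional[dict] = None,
    adapt_kwargs: Optional[dict] = None,
) -> str:
    """Return a label for the configuration mode the arguments select.

    Hypothetical helper for illustration only; it mirrors the
    mutual-exclusion checks of the constructor and nothing else.
    """
    if address:
        # An address means "connect to a running scheduler", so any option
        # that describes or creates a cluster is rejected.
        if cluster or cluster_class or cluster_kwargs or adapt_kwargs:
            raise ValueError(
                "Cannot specify `address` and "
                "`cluster`/`cluster_class`/`cluster_kwargs`/`adapt_kwargs`"
            )
        return "existing scheduler address"
    if cluster:
        # An existing cluster object cannot be combined with the options
        # that are only used to create a new cluster.
        if cluster_class or cluster_kwargs:
            raise ValueError(
                "Cannot specify `cluster` and `cluster_class`/`cluster_kwargs`"
            )
        return "existing cluster object"
    # Neither `address` nor `cluster` was given: a temporary cluster is used.
    return "temporary cluster"


print(check_cluster_options(address="192.0.2.255:8786"))
```

Passing `address` together with `cluster_kwargs`, for example, raises a `ValueError` in this sketch, just as the listed constructor does.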
Multiprocessing safety
Note that, because the DaskTaskRunner uses multiprocessing, calls to flows
in scripts must be guarded with if __name__ == "__main__": or warnings will
be displayed.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| cluster | Cluster | Currently running dask cluster; if one is not provided (or specified via the address kwarg), a temporary cluster will be created in DaskTaskRunner.start(). Defaults to None. | None |
| address | string | Address of a currently running dask scheduler. Defaults to None. | None |
| cluster_class | string or callable | The cluster class to use when creating a temporary dask cluster. Can be either the full class name (e.g. "distributed.LocalCluster"), or the class itself. | None |
| cluster_kwargs | dict | Additional kwargs to pass to the cluster_class when creating a temporary dask cluster. | None |
| adapt_kwargs | dict | Additional kwargs to pass to cluster.adapt when creating a temporary dask cluster. Note that adaptive scaling is only enabled if adapt_kwargs are provided. | None |
| client_kwargs | dict | Additional kwargs to use when creating a dask.distributed.Client. | None |
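Since `cluster_class` may be given either as a dotted name or as the class itself, the runner has to turn a string into a class before it can create the cluster. The source listing on this page does that with `from_qualified_name`. The sketch below shows one plausible way such a lookup can work using only the standard library; `resolve_cluster_class` is a hypothetical name, not a Prefect or Dask function, and `collections.OrderedDict` merely stands in for a cluster class so the example runs without Dask installed.

```python
import collections
from importlib import import_module
from typing import Any, Callable, Union


def resolve_cluster_class(cluster_class: Union[str, Callable]) -> Any:
    """Return the object named by a dotted path, or the input if not a string.

    Hypothetical helper for illustration; the real runner relies on
    Prefect's `from_qualified_name` instead.
    """
    if isinstance(cluster_class, str):
        # Split "package.module.ClassName" into module path and attribute.
        module_name, _, attr_name = cluster_class.rpartition(".")
        return getattr(import_module(module_name), attr_name)
    # Already a class (or other callable): use it unchanged.
    return cluster_class


# Both spellings resolve to the same object.
from_string = resolve_cluster_class("collections.OrderedDict")
from_class = resolve_cluster_class(collections.OrderedDict)
print(from_string is from_class)
```

With Dask installed, the same call with `"distributed.LocalCluster"` would be expected to return the `LocalCluster` class.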
class DaskTaskRunner(BaseTaskRunner):
    """
    A parallel task runner that submits tasks to the `dask.distributed` scheduler.
    By default a temporary `distributed.LocalCluster` is created (and
    subsequently torn down) within the `start()` contextmanager. To use a
    different cluster class (e.g.
    [`dask_kubernetes.KubeCluster`](https://kubernetes.dask.org/)), you can
    specify `cluster_class`/`cluster_kwargs`.

    Alternatively, if you already have a dask cluster running, you can provide
    the cluster object via the `cluster` kwarg or the address of the scheduler
    via the `address` kwarg.

    !!! warning "Multiprocessing safety"
        Note that, because the `DaskTaskRunner` uses multiprocessing, calls to flows
        in scripts must be guarded with `if __name__ == "__main__":` or warnings will
        be displayed.

    Args:
        cluster (distributed.deploy.Cluster, optional): Currently running dask cluster;
            if one is not provided (or specified via `address` kwarg), a temporary
            cluster will be created in `DaskTaskRunner.start()`. Defaults to `None`.
        address (string, optional): Address of a currently running dask
            scheduler. Defaults to `None`.
        cluster_class (string or callable, optional): The cluster class to use
            when creating a temporary dask cluster. Can be either the full
            class name (e.g. `"distributed.LocalCluster"`), or the class itself.
        cluster_kwargs (dict, optional): Additional kwargs to pass to the
            `cluster_class` when creating a temporary dask cluster.
        adapt_kwargs (dict, optional): Additional kwargs to pass to `cluster.adapt`
            when creating a temporary dask cluster. Note that adaptive scaling
            is only enabled if `adapt_kwargs` are provided.
        client_kwargs (dict, optional): Additional kwargs to use when creating a
            [`dask.distributed.Client`](https://distributed.dask.org/en/latest/api.html#client).

    Examples:
        Using a temporary local dask cluster:
        ```python
        from prefect import flow
        from prefect_dask.task_runners import DaskTaskRunner

        @flow(task_runner=DaskTaskRunner)
        def my_flow():
            ...
        ```
        Using a temporary cluster running elsewhere. Any Dask cluster class should
        work, here we use [dask-cloudprovider](https://cloudprovider.dask.org):
        ```python
        DaskTaskRunner(
            cluster_class="dask_cloudprovider.FargateCluster",
            cluster_kwargs={
                "image": "prefecthq/prefect:latest",
                "n_workers": 5,
            },
        )
        ```
        Connecting to an existing dask cluster:
        ```python
        DaskTaskRunner(address="192.0.2.255:8786")
        ```
    """

    def __init__(
        self,
        cluster: Optional[distributed.deploy.Cluster] = None,
        address: str = None,
        cluster_class: Union[str, Callable] = None,
        cluster_kwargs: dict = None,
        adapt_kwargs: dict = None,
        client_kwargs: dict = None,
    ):
        # Validate settings and infer defaults
        if address:
            if cluster or cluster_class or cluster_kwargs or adapt_kwargs:
                raise ValueError(
                    "Cannot specify `address` and "
                    "`cluster`/`cluster_class`/`cluster_kwargs`/`adapt_kwargs`"
                )
        elif cluster:
            if cluster_class or cluster_kwargs:
                raise ValueError(
                    "Cannot specify `cluster` and `cluster_class`/`cluster_kwargs`"
                )
            if not cluster.asynchronous:
                raise ValueError(
                    "The cluster must have `asynchronous=True` to be "
                    "used with `DaskTaskRunner`."
                )
        else:
            if isinstance(cluster_class, str):
                cluster_class = from_qualified_name(cluster_class)
            else:
                cluster_class = cluster_class

        # Create copies of incoming kwargs since we may mutate them
        cluster_kwargs = cluster_kwargs.copy() if cluster_kwargs else {}
        adapt_kwargs = adapt_kwargs.copy() if adapt_kwargs else {}
        client_kwargs = client_kwargs.copy() if client_kwargs else {}

        # Update kwargs defaults
        client_kwargs.setdefault("set_as_default", False)

        # The user cannot specify async/sync themselves
        if "asynchronous" in client_kwargs:
            raise ValueError(
                "`client_kwargs` cannot set `asynchronous`. "
                "This option is managed by Prefect."
            )
        if "asynchronous" in cluster_kwargs:
            raise ValueError(
                "`cluster_kwargs` cannot set `asynchronous`. "
                "This option is managed by Prefect."
            )

        # Store settings
        self.address = address
        self.cluster_class = cluster_class
        self.cluster_kwargs = cluster_kwargs
        self.adapt_kwargs = adapt_kwargs
        self.client_kwargs = client_kwargs

        # Runtime attributes
        self._client: "distributed.Client" = None
        self._cluster: "distributed.deploy.Cluster" = cluster
        self._dask_futures: Dict[str, "distributed.Future"] = {}

        super().__init__()

    @property
    def concurrency_type(self) -> TaskConcurrencyType:
        return (
            TaskConcurrencyType.PARALLEL
            if self.cluster_kwargs.get("processes")
            else TaskConcurrencyType.CONCURRENT
        )

    def duplicate(self):
        """
        Create a new instance of the task runner with the same settings.
        """
        return type(self)(
            address=self.address,
            cluster_class=self.cluster_class,
            cluster_kwargs=self.cluster_kwargs,
            adapt_kwargs=self.adapt_kwargs,
            client_kwargs=self.client_kwargs,
        )

    def __eq__(self, other: object) -> bool:
        """
        Check if an instance has the same settings as this task runner.
        """
        if type(self) == type(other):
            return (
                self.address == other.address
                and self.cluster_class == other.cluster_class
                and self.cluster_kwargs == other.cluster_kwargs
                and self.adapt_kwargs == other.adapt_kwargs
                and self.client_kwargs == other.client_kwargs
            )
        else:
            return NotImplemented

    async def submit(
        self,
        key: UUID,
        call: Callable[..., Awaitable[State[R]]],
    ) -> None:
        if not self._started:
            raise RuntimeError(
                "The task runner must be started before submitting work."
            )

        # unpack the upstream call in order to cast Prefect futures to Dask futures
        # where possible to optimize Dask task scheduling
        call_kwargs = self._optimize_futures(call.keywords)

        if "task_run" in call_kwargs:
            task_run = call_kwargs["task_run"]
            flow_run = FlowRunContext.get().flow_run
            # Dask displays the text up to the first '-' as the name; the task run key
            # should include the task run name for readability in the Dask console.
            # For cases where the task run fails and reruns for a retried flow run,
            # the flow run count is included so that the new key will not match
            # the failed run's key, therefore not retrieving from the Dask cache.
            dask_key = f"{task_run.name}-{task_run.id.hex}-{flow_run.run_count}"
        else:
            dask_key = str(key)

        self._dask_futures[key] = self._client.submit(
            call.func,
            key=dask_key,
            # Dask defaults to treating functions as pure, but we set this here for
            # explicit expectations. If this task run is submitted to Dask twice, the
            # result of the first run should be returned. Subsequent runs would return
            # `Abort` exceptions if they were submitted again.
            pure=True,
            **call_kwargs,
        )

    def _get_dask_future(self, key: UUID) -> "distributed.Future":
        """
        Retrieve the dask future corresponding to a Prefect future.
        The Dask future is for the `run_fn`, which should return a `State`.
        """
        return self._dask_futures[key]

    def _optimize_futures(self, expr):
        def visit_fn(expr):
            if isinstance(expr, PrefectFuture):
                dask_future = self._dask_futures.get(expr.key)
                if dask_future is not None:
                    return dask_future
            # Fallback to return the expression unaltered
            return expr

        return visit_collection(expr, visit_fn=visit_fn, return_data=True)

    async def wait(self, key: UUID, timeout: float = None) -> Optional[State]:
        future = self._get_dask_future(key)
        try:
            return await future.result(timeout=timeout)
        except distributed.TimeoutError:
            return None
        except BaseException as exc:
            return await exception_to_crashed_state(exc)

    async def _start(self, exit_stack: AsyncExitStack):
        """
        Start the task runner and prep for context exit.
        - Creates a cluster if an external address is not set.
        - Creates a client to connect to the cluster.
        - Pushes a call to wait for all running futures to complete on exit.
        """
        if self._cluster:
            self.logger.info(f"Connecting to existing Dask cluster {self._cluster}")
            self._connect_to = self._cluster
            if self.adapt_kwargs:
                self._cluster.adapt(**self.adapt_kwargs)
        elif self.address:
            self.logger.info(
                f"Connecting to an existing Dask cluster at {self.address}"
            )
            self._connect_to = self.address
        else:
            self.cluster_class = self.cluster_class or distributed.LocalCluster

            self.logger.info(
                f"Creating a new Dask cluster with "
                f"`{to_qualified_name(self.cluster_class)}`"
            )
            self._connect_to = self._cluster = await exit_stack.enter_async_context(
                self.cluster_class(asynchronous=True, **self.cluster_kwargs)
            )
            if self.adapt_kwargs:
                adapt_response = self._cluster.adapt(**self.adapt_kwargs)
                if inspect.isawaitable(adapt_response):
                    await adapt_response

        self._client = await exit_stack.enter_async_context(
            distributed.Client(
                self._connect_to, asynchronous=True, **self.client_kwargs
            )
        )

        if self._client.dashboard_link:
            self.logger.info(
                f"The Dask dashboard is available at {self._client.dashboard_link}",
            )

    def __getstate__(self):
        """
        Allow the `DaskTaskRunner` to be serialized by dropping
        the `distributed.Client`, which contains locks.
        Must be deserialized on a dask worker.
        """
        data = self.__dict__.copy()
        data.update({k: None for k in {"_client", "_cluster", "_connect_to"}})
        return data

    def __setstate__(self, data: dict):
        """
        Restore the `distributed.Client` by loading the client on a dask worker.
        """
        self.__dict__.update(data)
        self._client = distributed.get_client()
Create a new instance of the task runner with the same settings.
Source code in prefect_dask/task_runners.py
def duplicate(self):
    """
    Create a new instance of the task runner with the same settings.
    """
    return type(self)(
        address=self.address,
        cluster_class=self.cluster_class,
        cluster_kwargs=self.cluster_kwargs,
        adapt_kwargs=self.adapt_kwargs,
        client_kwargs=self.client_kwargs,
    )
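The `duplicate` pattern above, rebuilding an instance from its stored settings, pairs naturally with a settings-based `__eq__`. The sketch below reproduces both on a small stand-in class so the behavior can be tried without Prefect or Dask installed. `RunnerSettings` is a hypothetical class invented for this illustration, and it keeps only two of the runner's settings.

```python
from typing import Optional


class RunnerSettings:
    """Hypothetical stand-in holding two of the runner's settings."""

    def __init__(
        self,
        address: Optional[str] = None,
        cluster_kwargs: Optional[dict] = None,
    ):
        self.address = address
        # Copy the incoming dict so later mutation stays local to this instance.
        self.cluster_kwargs = cluster_kwargs.copy() if cluster_kwargs else {}

    def duplicate(self) -> "RunnerSettings":
        # type(self) keeps subclasses intact; settings are passed back
        # through the constructor, which copies them again.
        return type(self)(
            address=self.address,
            cluster_kwargs=self.cluster_kwargs,
        )

    def __eq__(self, other: object) -> bool:
        # Equal only when the types match exactly and all settings agree.
        if type(self) == type(other):
            return (
                self.address == other.address
                and self.cluster_kwargs == other.cluster_kwargs
            )
        return NotImplemented


original = RunnerSettings(cluster_kwargs={"n_workers": 2})
copy = original.duplicate()
print(copy == original, copy is original)
```

Because the constructor copies its kwargs, changing `copy.cluster_kwargs` afterwards leaves `original` untouched in this sketch, and the two instances then stop comparing equal.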