diff --git a/mmrazor/engine/runner/utils/check.py b/mmrazor/engine/runner/utils/check.py index ad774f647..e16da3c5b 100644 --- a/mmrazor/engine/runner/utils/check.py +++ b/mmrazor/engine/runner/utils/check.py @@ -34,10 +34,13 @@ def check_subnet_resources( _, sliced_model = export_fix_subnet(model, slice_weight=True) model_to_check = sliced_model.architecture # type: ignore + measure_latency = True if 'latency' in constraints_range.keys() else False if isinstance(model_to_check, BaseDetector): - results = estimator.estimate(model=model_to_check.backbone) + results = estimator.estimate(model=model_to_check.backbone, + measure_latency=measure_latency) else: - results = estimator.estimate(model=model_to_check) + results = estimator.estimate(model=model_to_check, + measure_latency=measure_latency) for k, v in constraints_range.items(): if not isinstance(v, (list, tuple)): diff --git a/mmrazor/models/task_modules/estimators/resource_estimator.py b/mmrazor/models/task_modules/estimators/resource_estimator.py index ac5292d0c..45716c3a4 100644 --- a/mmrazor/models/task_modules/estimators/resource_estimator.py +++ b/mmrazor/models/task_modules/estimators/resource_estimator.py @@ -92,10 +92,12 @@ def __init__( self.flops_params_cfg = dict() self.latency_cfg = latency_cfg if latency_cfg else dict() - def estimate(self, - model: torch.nn.Module, - flops_params_cfg: dict = None, - latency_cfg: dict = None) -> Dict[str, Union[float, str]]: + def estimate( + self, + model: torch.nn.Module, + flops_params_cfg: dict = None, + latency_cfg: dict = None, + measure_latency: bool = False) -> Dict[str, Union[float, str]]: """Estimate the resources(flops/params/latency) of the given model. This method will first parse the merged :attr:`self.flops_params_cfg` @@ -106,6 +108,7 @@ def estimate(self, flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. Default to None. latency_cfg (dict): Cfg for estimating latency. Default to None. + measure_latency (bool): Measure latency or not. Default to False. NOTE: If the `flops_params_cfg` and `latency_cfg` are both None, this method will only estimate FLOPs/params with default settings. @@ -115,7 +118,6 @@ def estimate(self, results(FLOPs, params and latency). """ resource_metrics = dict() - measure_latency = True if latency_cfg else False if flops_params_cfg: flops_params_cfg = {**self.flops_params_cfg, **flops_params_cfg}