Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Estimate sub-model latency in the process of NAS #568

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions mmrazor/engine/runner/utils/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@ def check_subnet_resources(
_, sliced_model = export_fix_subnet(model, slice_weight=True)

model_to_check = sliced_model.architecture # type: ignore
measure_latency = True if 'latency' in constraints_range.keys() else False
if isinstance(model_to_check, BaseDetector):
results = estimator.estimate(model=model_to_check.backbone)
results = estimator.estimate(model=model_to_check.backbone,
measure_latency=measure_latency)
else:
results = estimator.estimate(model=model_to_check)
results = estimator.estimate(model=model_to_check,
measure_latency=measure_latency)

for k, v in constraints_range.items():
if not isinstance(v, (list, tuple)):
Expand Down
12 changes: 7 additions & 5 deletions mmrazor/models/task_modules/estimators/resource_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,12 @@ def __init__(
self.flops_params_cfg = dict()
self.latency_cfg = latency_cfg if latency_cfg else dict()

def estimate(self,
model: torch.nn.Module,
flops_params_cfg: dict = None,
latency_cfg: dict = None) -> Dict[str, Union[float, str]]:
def estimate(
self,
model: torch.nn.Module,
flops_params_cfg: dict = None,
latency_cfg: dict = None,
measure_latency: bool = False) -> Dict[str, Union[float, str]]:
"""Estimate the resources(flops/params/latency) of the given model.

This method will first parse the merged :attr:`self.flops_params_cfg`
Expand All @@ -106,6 +108,7 @@ def estimate(self,
flops_params_cfg (dict): Cfg for estimating FLOPs and parameters.
Default to None.
latency_cfg (dict): Cfg for estimating latency. Default to None.
measure_latency (bool): Measure latency or not. Default to False.

NOTE: If the `flops_params_cfg` and `latency_cfg` are both None,
this method will only estimate FLOPs/params with default settings.
Expand All @@ -115,7 +118,6 @@ def estimate(self,
results(FLOPs, params and latency).
"""
resource_metrics = dict()
measure_latency = True if latency_cfg else False

if flops_params_cfg:
flops_params_cfg = {**self.flops_params_cfg, **flops_params_cfg}
Expand Down