Skip to content

Commit

Permalink
Merge pull request #5897 from kingthorin/seq-policy2
Browse files Browse the repository at this point in the history
automation & sequence: fix use of default policies and set Sequence as default sequence policy
  • Loading branch information
thc202 authored Nov 15, 2024
2 parents 603db01 + e79bd6f commit e26cfb0
Show file tree
Hide file tree
Showing 12 changed files with 197 additions and 33 deletions.
1 change: 1 addition & 0 deletions addOns/automation/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
### Fixed
- Templates generated with `-autogenmin` or `-autogenmax` were invalid in some cases.
- Allow to choose one thread for the `activeScan` job through the GUI.
- Active Scan jobs will once again use the default policy if neither a policy nor a policyDefinition has been set.

## [0.43.0] - 2024-10-07
### Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,6 @@ public void verifyParameters(AutomationProgress progress) {
params, this.parameters, this.getName(), null, progress);
break;
case "policyDefinition":
// Parse the policy defn
PolicyDefinition.parsePolicyDefinition(
jobData.get(key), policyDefinition, this.getName(), progress);
break;
case "name":
case "tests":
case "type":
Expand All @@ -115,7 +111,8 @@ public void verifyParameters(AutomationProgress progress) {
break;
}
}

policyDefinition.parsePolicyDefinition(
jobData.get("policyDefinition"), this.getName(), progress);
this.verifyUser(this.getParameters().getUser(), progress);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ public void verifyParameters(AutomationProgress progress) {
break;
case "policyDefinition":
// Parse the policy defn
PolicyDefinition.parsePolicyDefinition(
jobData.get(key), policyDefinition, this.getName(), progress);
policyDefinition.parsePolicyDefinition(
jobData.get(key), this.getName(), progress);
break;
case "name":
case "tests":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import lombok.Getter;
import lombok.Setter;
import org.parosproxy.paros.Constant;
Expand All @@ -31,36 +32,49 @@
import org.parosproxy.paros.core.scanner.PluginFactory;
import org.zaproxy.addon.automation.AutomationData;
import org.zaproxy.addon.automation.AutomationProgress;
import org.zaproxy.addon.automation.jobs.PolicyDefinition.Rule;
import org.zaproxy.zap.extension.ascan.ScanPolicy;

@Getter
@Setter
public class PolicyDefinition extends AutomationData {

private static final String DEFAULT_STRENGTH_KEY = "defaultStrength";
private static final String DEFAULT_THRESHOLD_KEY = "defaultThreshold";

protected static final String RULES_ELEMENT_NAME = "rules";

private String defaultStrength = JobUtils.strengthToI18n(AttackStrength.MEDIUM.name());
private String defaultThreshold = JobUtils.thresholdToI18n(AlertThreshold.MEDIUM.name());
private List<Rule> rules = new ArrayList<>();

public static void parsePolicyDefinition(
Object policyDefnObj,
PolicyDefinition policyDefinition,
String jobName,
AutomationProgress progress) {
public void parsePolicyDefinition(
Object policyDefnObj, String jobName, AutomationProgress progress) {

if (policyDefnObj == null) {
this.defaultStrength = null;
return;
}
if (policyDefnObj instanceof LinkedHashMap<?, ?>) {
LinkedHashMap<?, ?> policyDefnData = (LinkedHashMap<?, ?>) policyDefnObj;
@SuppressWarnings("unchecked")
LinkedHashMap<Object, Object> policyDefnData =
(LinkedHashMap<Object, Object>) policyDefnObj;

checkAndSetDefault(policyDefnData, DEFAULT_STRENGTH_KEY, AttackStrength.MEDIUM.name());
checkAndSetDefault(policyDefnData, DEFAULT_THRESHOLD_KEY, AlertThreshold.MEDIUM.name());

if (policyDefnData.isEmpty() || undefinedDefinition(policyDefnData)) {
this.defaultStrength = null;
return;
}

JobUtils.applyParamsToObject(
policyDefnData,
policyDefinition,
this,
jobName,
new String[] {PolicyDefinition.RULES_ELEMENT_NAME},
progress);

List<Rule> rules = new ArrayList<>();
this.rules = new ArrayList<>();
ScanPolicy scanPolicy = new ScanPolicy();
PluginFactory pluginFactory = scanPolicy.getPluginFactory();

Expand Down Expand Up @@ -89,7 +103,7 @@ public static void parsePolicyDefinition(
if (strength != null) {
rule.setStrength(strength.name().toLowerCase());
}
rules.add(rule);
this.rules.add(rule);

} else {
progress.warn(
Expand All @@ -98,7 +112,6 @@ public static void parsePolicyDefinition(
}
}
}
policyDefinition.setRules(rules);
} else if (o != null) {
progress.warn(
Constant.messages.getString(
Expand All @@ -117,7 +130,32 @@ public static void parsePolicyDefinition(
}
}

private static void checkAndSetDefault(
LinkedHashMap<Object, Object> policyDefnData, String key, String value) {
if (policyDefnData.containsKey(key) && policyDefnData.get(key) == null) {
policyDefnData.put(key, value);
}
}

private static boolean undefinedDefinition(Map<?, ?> policyDefnData) {
Object rules = policyDefnData.get(RULES_ELEMENT_NAME);
boolean rulesInvalid = false;
if (rules instanceof List<?>) {
rulesInvalid = ((List<?>) rules).isEmpty();
} else if ((String) rules == null) {
rulesInvalid = true;
}
return (String) policyDefnData.get(DEFAULT_STRENGTH_KEY) == null
&& (String) policyDefnData.get(DEFAULT_THRESHOLD_KEY) == null
&& rulesInvalid;
}

public ScanPolicy getScanPolicy(String jobName, AutomationProgress progress) {
if (getDefaultStrength() == null) {
// Nothing defined
return null;
}

ScanPolicy scanPolicy = new ScanPolicy();

// Set default strength
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ <H2>YAML</H2>
threshold: # String: The Alert Threshold for this rule, one of Off, Low, Medium, High, default: Medium
</pre>

<strong>Note</strong>: Unless the <code>defaultThreshold</code> of the <code>policyDefinition</code> is <code>OFF</code> all rules will be enabled to start with.

<p>
The policy can be one defined by a previous <a href="job-ascanpolicy.html">activeScan-policy</a> job, or by a scan policy file
that has been put in <code>policies</code> directory under ZAP's <a href="https://www.zaproxy.org/faq/what-is-the-default-directory-that-zap-uses/">HOME directory</a> .
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.mockito.ArgumentMatcher;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
Expand Down Expand Up @@ -435,7 +437,7 @@ void shouldReturnWarningOnUnexpectedElement() throws MalformedURLException {
}

@Test
void shouldReturnScanPolicyForDefaultData() throws MalformedURLException {
void shouldReturnNullScanPolicyForEmptyData() {
// Given
ActiveScanJob job = new ActiveScanJob();
AutomationProgress progress = new AutomationProgress();
Expand All @@ -447,14 +449,66 @@ void shouldReturnScanPolicyForDefaultData() throws MalformedURLException {
job.verifyParameters(progress);
ScanPolicy policy = job.getData().getPolicyDefinition().getScanPolicy(null, progress);

// Then
assertThat(policy, is(equalTo(null)));
assertThat(progress.hasWarnings(), is(equalTo(false)));
assertThat(progress.hasErrors(), is(equalTo(false)));
}

@ParameterizedTest
@EnumSource(
value = AttackStrength.class,
mode = EnumSource.Mode.EXCLUDE,
names = {"DEFAULT"})
void shouldReturnScanPolicyIfOnlyDefaultStrength(AttackStrength attackStrength) {
// Given
ActiveScanJob job = new ActiveScanJob();
AutomationProgress progress = new AutomationProgress();
LinkedHashMap<String, LinkedHashMap<?, ?>> data = new LinkedHashMap<>();
LinkedHashMap<String, String> policyDefn = new LinkedHashMap<>();
policyDefn.put("defaultStrength", attackStrength.name());
data.put("policyDefinition", policyDefn);

// When
job.setJobData(data);
job.verifyParameters(progress);
ScanPolicy policy = job.getData().getPolicyDefinition().getScanPolicy(null, progress);

// Then
assertThat(policy, is(notNullValue()));
assertThat(policy.getDefaultStrength(), is(AttackStrength.MEDIUM));
assertThat(policy.getDefaultStrength(), is(attackStrength));
assertThat(policy.getDefaultThreshold(), is(AlertThreshold.MEDIUM));
assertThat(progress.hasWarnings(), is(equalTo(false)));
assertThat(progress.hasErrors(), is(equalTo(false)));
}

@ParameterizedTest
@EnumSource(
value = AlertThreshold.class,
mode = EnumSource.Mode.EXCLUDE,
names = {"DEFAULT"})
void shouldReturnScanPolicyIfOnlyDefaultThreshold(AlertThreshold alertThreshold) {
// Given
ActiveScanJob job = new ActiveScanJob();
AutomationProgress progress = new AutomationProgress();
LinkedHashMap<String, LinkedHashMap<?, ?>> data = new LinkedHashMap<>();
LinkedHashMap<String, String> policyDefn = new LinkedHashMap<>();
policyDefn.put("defaultThreshold", alertThreshold.name());
data.put("policyDefinition", policyDefn);

// When
job.setJobData(data);
job.verifyParameters(progress);
ScanPolicy policy = job.getData().getPolicyDefinition().getScanPolicy(null, progress);

// Then
assertThat(policy, is(notNullValue()));
assertThat(policy.getDefaultStrength(), is(AttackStrength.MEDIUM));
assertThat(policy.getDefaultThreshold(), is(alertThreshold));
assertThat(progress.hasWarnings(), is(equalTo(false)));
assertThat(progress.hasErrors(), is(equalTo(false)));
}

@Test
void shouldSetScanPolicyDefaults() throws MalformedURLException {
// Given
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
import org.zaproxy.zap.utils.I18N;
import org.zaproxy.zap.utils.ZapXmlConfiguration;

class ActiveScanPolicyJobPolicyUnitTest {
class ActiveScanPolicyJobUnitTest {

private static MockedStatic<CommandLine> mockedCmdLine;
private ExtensionActiveScan extAScan;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;

import java.io.IOException;
import java.nio.file.Files;
Expand All @@ -33,17 +35,23 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.parosproxy.paros.CommandLine;
import org.parosproxy.paros.Constant;
import org.parosproxy.paros.core.scanner.AbstractPlugin;
import org.parosproxy.paros.core.scanner.Plugin;
import org.parosproxy.paros.core.scanner.Plugin.AlertThreshold;
import org.parosproxy.paros.core.scanner.Plugin.AttackStrength;
import org.parosproxy.paros.core.scanner.PluginFactory;
import org.parosproxy.paros.core.scanner.PluginFactoryTestHelper;
import org.parosproxy.paros.core.scanner.PluginTestHelper;
import org.yaml.snakeyaml.Yaml;
import org.zaproxy.addon.automation.AutomationProgress;
import org.zaproxy.addon.automation.jobs.PolicyDefinition.Rule;
import org.zaproxy.zap.extension.ascan.ScanPolicy;
import org.zaproxy.zap.utils.I18N;

class PolicyDefinitionUnitTest {
Expand Down Expand Up @@ -189,7 +197,7 @@ void shouldParseValidDefinition() {
Object data = yaml.load(yamlStr);

// When
PolicyDefinition.parsePolicyDefinition(data, policyDefinition, "test", progress);
policyDefinition.parsePolicyDefinition(data, "test", progress);

// Then
assertThat(progress.hasErrors(), is(equalTo(false)));
Expand Down Expand Up @@ -223,7 +231,7 @@ void shouldWarnIfUnknownRule() {
Object data = yaml.load(yamlStr);

// When
PolicyDefinition.parsePolicyDefinition(data, policyDefinition, "test", progress);
policyDefinition.parsePolicyDefinition(data, "test", progress);

// Then
assertThat(progress.hasErrors(), is(equalTo(false)));
Expand All @@ -250,7 +258,7 @@ void shouldWarnIfDefnNotList() {
Object data = yaml.load(yamlStr);

// When
PolicyDefinition.parsePolicyDefinition(data, policyDefinition, "test", progress);
policyDefinition.parsePolicyDefinition(data, "test", progress);

// Then
assertThat(progress.hasErrors(), is(equalTo(false)));
Expand All @@ -270,7 +278,7 @@ void shouldWarnIfRulesNotList() {
Object data = yaml.load(yamlStr);

// When
PolicyDefinition.parsePolicyDefinition(data, policyDefinition, "test", progress);
policyDefinition.parsePolicyDefinition(data, "test", progress);

// Then
assertThat(progress.hasErrors(), is(equalTo(false)));
Expand All @@ -279,4 +287,66 @@ void shouldWarnIfRulesNotList() {
assertThat(
progress.getWarnings().get(0), is(equalTo("!automation.error.options.badlist!")));
}

@ParameterizedTest
@ValueSource(
strings = {
" defaultStrength: \n" + " defaultThreshold: \n" + " rules: ",
"defaultStrength:",
"defaultThreshold:"
})
void shouldReturnPolicyWithDefaultsIfDefinitionYamlContainsUndefinedStrengthThreshold(
String defnYamlStr) {
// Given
AutomationProgress progress = new AutomationProgress();
Yaml yaml = new Yaml();
Object data = yaml.load(defnYamlStr);

// When
policyDefinition.parsePolicyDefinition(data, "test", progress);

// Then
assertThat(progress.hasErrors(), is(equalTo(false)));
assertThat(progress.hasWarnings(), is(equalTo(false)));
ScanPolicy policy = policyDefinition.getScanPolicy("test", progress);
assertThat(policy, is(notNullValue()));
assertThat(policy.getDefaultStrength(), is(equalTo(AttackStrength.MEDIUM)));
assertThat(policy.getDefaultThreshold(), is(equalTo(AlertThreshold.MEDIUM)));
List<Plugin> rules = policy.getPluginFactory().getAllPlugin();
assertValueAppliedToRules(
rules.get(0),
rules.get(rules.size() - 1),
AttackStrength.MEDIUM,
AlertThreshold.MEDIUM);
}

private static void assertValueAppliedToRules(
Plugin first, Plugin last, AttackStrength expectedStr, AlertThreshold expectedThold) {
assertThat(first.getAttackStrength(), is(equalTo(expectedStr)));
assertThat(last.getAttackStrength(), is(equalTo(expectedStr)));
assertThat(first.getAlertThreshold(), is(equalTo(expectedThold)));
assertThat(last.getAlertThreshold(), is(equalTo(expectedThold)));
}

@ParameterizedTest
@ValueSource(
strings = {
"{}",
"",
"rules: \n",
})
void shouldReturnNullPolicyIfDefinitionYamlIsEmptyOrNullObject(String defnYamlStr) {
// Given
AutomationProgress progress = new AutomationProgress();
Yaml yaml = new Yaml();
Object data = yaml.load(defnYamlStr);

// When
policyDefinition.parsePolicyDefinition(data, "test", progress);

// Then
assertThat(progress.hasErrors(), is(equalTo(false)));
assertThat(progress.hasWarnings(), is(equalTo(false)));
assertThat(policyDefinition.getScanPolicy("test", progress), is(nullValue()));
}
}
Loading

0 comments on commit e26cfb0

Please sign in to comment.