Skip to content

Commit

Permalink
PolicyDefinition make parseDefinition non-static and adjust accesses
Browse files Browse the repository at this point in the history
Signed-off-by: kingthorin <kingthorin@users.noreply.github.com>
  • Loading branch information
kingthorin committed Nov 15, 2024
1 parent 381c948 commit e79bd6f
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,6 @@ public void verifyParameters(AutomationProgress progress) {
params, this.parameters, this.getName(), null, progress);
break;
case "policyDefinition":
// Parse the policy defn
PolicyDefinition.parsePolicyDefinition(
jobData.get(key), policyDefinition, this.getName(), progress);
break;
case "name":
case "tests":
case "type":
Expand All @@ -115,7 +111,8 @@ public void verifyParameters(AutomationProgress progress) {
break;
}
}

policyDefinition.parsePolicyDefinition(
jobData.get("policyDefinition"), this.getName(), progress);
this.verifyUser(this.getParameters().getUser(), progress);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ public void verifyParameters(AutomationProgress progress) {
break;
case "policyDefinition":
// Parse the policy defn
PolicyDefinition.parsePolicyDefinition(
jobData.get(key), policyDefinition, this.getName(), progress);
policyDefinition.parsePolicyDefinition(
jobData.get(key), this.getName(), progress);
break;
case "name":
case "tests":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import lombok.Getter;
import lombok.Setter;
import org.parosproxy.paros.Constant;
Expand All @@ -37,39 +38,47 @@
@Setter
public class PolicyDefinition extends AutomationData {

private static final String DEFAULT_STRENGTH_KEY = "defaultStrength";
private static final String DEFAULT_THRESHOLD_KEY = "defaultThreshold";

protected static final String RULES_ELEMENT_NAME = "rules";

private String defaultStrength = JobUtils.strengthToI18n(AttackStrength.MEDIUM.name());
private String defaultThreshold = JobUtils.thresholdToI18n(AlertThreshold.MEDIUM.name());
private List<Rule> rules = new ArrayList<>();

public static void parsePolicyDefinition(
Object policyDefnObj,
PolicyDefinition policyDefinition,
String jobName,
AutomationProgress progress) {
public void parsePolicyDefinition(
Object policyDefnObj, String jobName, AutomationProgress progress) {

if (policyDefnObj == null) {
this.defaultStrength = null;
return;
}
if (policyDefnObj instanceof LinkedHashMap<?, ?>) {
LinkedHashMap<?, ?> policyDefnData = (LinkedHashMap<?, ?>) policyDefnObj;
@SuppressWarnings("unchecked")
LinkedHashMap<Object, Object> policyDefnData =
(LinkedHashMap<Object, Object>) policyDefnObj;

checkAndSetDefault(policyDefnData, DEFAULT_STRENGTH_KEY, AttackStrength.MEDIUM.name());
checkAndSetDefault(policyDefnData, DEFAULT_THRESHOLD_KEY, AlertThreshold.MEDIUM.name());

if (policyDefnData.isEmpty()) {
policyDefinition.setDefaultStrength(null);
if (policyDefnData.isEmpty() || undefinedDefinition(policyDefnData)) {
this.defaultStrength = null;
return;
}

JobUtils.applyParamsToObject(
policyDefnData,
policyDefinition,
this,
jobName,
new String[] {PolicyDefinition.RULES_ELEMENT_NAME},
progress);

List<Rule> rules = new ArrayList<>();
this.rules = new ArrayList<>();
ScanPolicy scanPolicy = new ScanPolicy();
PluginFactory pluginFactory = scanPolicy.getPluginFactory();

Object o = policyDefnData.get(RULES_ELEMENT_NAME);

if (o instanceof ArrayList<?>) {
ArrayList<?> ruleData = (ArrayList<?>) o;
for (Object ruleObj : ruleData) {
Expand All @@ -94,7 +103,7 @@ public static void parsePolicyDefinition(
if (strength != null) {
rule.setStrength(strength.name().toLowerCase());
}
rules.add(rule);
this.rules.add(rule);

} else {
progress.warn(
Expand All @@ -103,7 +112,6 @@ public static void parsePolicyDefinition(
}
}
}
policyDefinition.setRules(rules);
} else if (o != null) {
progress.warn(
Constant.messages.getString(
Expand All @@ -122,6 +130,26 @@ public static void parsePolicyDefinition(
}
}

private static void checkAndSetDefault(
LinkedHashMap<Object, Object> policyDefnData, String key, String value) {
if (policyDefnData.containsKey(key) && policyDefnData.get(key) == null) {
policyDefnData.put(key, value);
}
}

private static boolean undefinedDefinition(Map<?, ?> policyDefnData) {
Object rules = policyDefnData.get(RULES_ELEMENT_NAME);
boolean rulesInvalid = false;
if (rules instanceof List<?>) {
rulesInvalid = ((List<?>) rules).isEmpty();
} else if ((String) rules == null) {
rulesInvalid = true;
}
return (String) policyDefnData.get(DEFAULT_STRENGTH_KEY) == null
&& (String) policyDefnData.get(DEFAULT_THRESHOLD_KEY) == null
&& rulesInvalid;
}

public ScanPolicy getScanPolicy(String jobName, AutomationProgress progress) {
if (getDefaultStrength() == null) {
// Nothing defined
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;

import java.io.IOException;
import java.nio.file.Files;
Expand All @@ -33,17 +35,23 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.parosproxy.paros.CommandLine;
import org.parosproxy.paros.Constant;
import org.parosproxy.paros.core.scanner.AbstractPlugin;
import org.parosproxy.paros.core.scanner.Plugin;
import org.parosproxy.paros.core.scanner.Plugin.AlertThreshold;
import org.parosproxy.paros.core.scanner.Plugin.AttackStrength;
import org.parosproxy.paros.core.scanner.PluginFactory;
import org.parosproxy.paros.core.scanner.PluginFactoryTestHelper;
import org.parosproxy.paros.core.scanner.PluginTestHelper;
import org.yaml.snakeyaml.Yaml;
import org.zaproxy.addon.automation.AutomationProgress;
import org.zaproxy.addon.automation.jobs.PolicyDefinition.Rule;
import org.zaproxy.zap.extension.ascan.ScanPolicy;
import org.zaproxy.zap.utils.I18N;

class PolicyDefinitionUnitTest {
Expand Down Expand Up @@ -189,7 +197,7 @@ void shouldParseValidDefinition() {
Object data = yaml.load(yamlStr);

// When
PolicyDefinition.parsePolicyDefinition(data, policyDefinition, "test", progress);
policyDefinition.parsePolicyDefinition(data, "test", progress);

// Then
assertThat(progress.hasErrors(), is(equalTo(false)));
Expand Down Expand Up @@ -223,7 +231,7 @@ void shouldWarnIfUnknownRule() {
Object data = yaml.load(yamlStr);

// When
PolicyDefinition.parsePolicyDefinition(data, policyDefinition, "test", progress);
policyDefinition.parsePolicyDefinition(data, "test", progress);

// Then
assertThat(progress.hasErrors(), is(equalTo(false)));
Expand All @@ -250,7 +258,7 @@ void shouldWarnIfDefnNotList() {
Object data = yaml.load(yamlStr);

// When
PolicyDefinition.parsePolicyDefinition(data, policyDefinition, "test", progress);
policyDefinition.parsePolicyDefinition(data, "test", progress);

// Then
assertThat(progress.hasErrors(), is(equalTo(false)));
Expand All @@ -270,7 +278,7 @@ void shouldWarnIfRulesNotList() {
Object data = yaml.load(yamlStr);

// When
PolicyDefinition.parsePolicyDefinition(data, policyDefinition, "test", progress);
policyDefinition.parsePolicyDefinition(data, "test", progress);

// Then
assertThat(progress.hasErrors(), is(equalTo(false)));
Expand All @@ -279,4 +287,66 @@ void shouldWarnIfRulesNotList() {
assertThat(
progress.getWarnings().get(0), is(equalTo("!automation.error.options.badlist!")));
}

@ParameterizedTest
@ValueSource(
strings = {
" defaultStrength: \n" + " defaultThreshold: \n" + " rules: ",
"defaultStrength:",
"defaultThreshold:"
})
void shouldReturnPolicyWithDefaultsIfDefinitionYamlContainsUndefinedStrengthThreshold(
String defnYamlStr) {
// Given
AutomationProgress progress = new AutomationProgress();
Yaml yaml = new Yaml();
Object data = yaml.load(defnYamlStr);

// When
policyDefinition.parsePolicyDefinition(data, "test", progress);

// Then
assertThat(progress.hasErrors(), is(equalTo(false)));
assertThat(progress.hasWarnings(), is(equalTo(false)));
ScanPolicy policy = policyDefinition.getScanPolicy("test", progress);
assertThat(policy, is(notNullValue()));
assertThat(policy.getDefaultStrength(), is(equalTo(AttackStrength.MEDIUM)));
assertThat(policy.getDefaultThreshold(), is(equalTo(AlertThreshold.MEDIUM)));
List<Plugin> rules = policy.getPluginFactory().getAllPlugin();
assertValueAppliedToRules(
rules.get(0),
rules.get(rules.size() - 1),
AttackStrength.MEDIUM,
AlertThreshold.MEDIUM);
}

private static void assertValueAppliedToRules(
Plugin first, Plugin last, AttackStrength expectedStr, AlertThreshold expectedThold) {
assertThat(first.getAttackStrength(), is(equalTo(expectedStr)));
assertThat(last.getAttackStrength(), is(equalTo(expectedStr)));
assertThat(first.getAlertThreshold(), is(equalTo(expectedThold)));
assertThat(last.getAlertThreshold(), is(equalTo(expectedThold)));
}

@ParameterizedTest
@ValueSource(
strings = {
"{}",
"",
"rules: \n",
})
void shouldReturnNullPolicyIfDefinitionYamlIsEmptyOrNullObject(String defnYamlStr) {
// Given
AutomationProgress progress = new AutomationProgress();
Yaml yaml = new Yaml();
Object data = yaml.load(defnYamlStr);

// When
policyDefinition.parsePolicyDefinition(data, "test", progress);

// Then
assertThat(progress.hasErrors(), is(equalTo(false)));
assertThat(progress.hasWarnings(), is(equalTo(false)));
assertThat(policyDefinition.getScanPolicy("test", progress), is(nullValue()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,6 @@ public void verifyParameters(AutomationProgress progress) {
params, this.parameters, this.getName(), null, progress);
break;
case "policyDefinition":
// Parse the policy defn
policyDefinition.parsePolicyDefinition(
jobData.get(key), policyDefinition, this.getName(), progress);
break;
case "name":
case "tests":
case "type":
Expand All @@ -134,6 +130,8 @@ public void verifyParameters(AutomationProgress progress) {
break;
}
}
policyDefinition.parsePolicyDefinition(
jobData.get("policyDefinition"), this.getName(), progress);

this.verifyUser(this.getParameters().getUser(), progress);
}
Expand Down

0 comments on commit e79bd6f

Please sign in to comment.