diff --git a/src/correction.cc b/src/correction.cc index b41c738..1c7f52f 100644 --- a/src/correction.cc +++ b/src/correction.cc @@ -130,12 +130,18 @@ namespace { const std::vector& values; }; - std::size_t find_bin_idx(double value, + std::size_t find_bin_idx(Variable::Type value_variant, const std::variant<_UniformBins, _NonUniformBins> &bins_, const _FlowBehavior &flow, std::size_t variableIdx, const char *name) { + double value = std::visit([](auto&& arg) -> double { + using T = std::decay_t; + if constexpr (std::is_same_v) return static_cast(arg); + else if constexpr (std::is_same_v) return arg; + else throw std::logic_error("I should not have ever seen a string"); + }, value_variant); if ( auto *bins = std::get_if<_UniformBins>(&bins_) ) { // uniform binning if (value < bins->low || value >= bins->high) { switch (flow) { @@ -187,7 +193,7 @@ namespace { return binIdx; } - size_t input_index(const std::string_view name, const std::vector &inputs) { + size_t find_input_index(const std::string_view name, const std::vector &inputs) { size_t idx = 0; for (const auto& var : inputs) { if ( name == var.name() ) return idx; @@ -287,7 +293,7 @@ Formula::Formula(const JSONObject& json, const std::vector& inputs, bo std::vector variableIdx; for (const auto& item : json.getRequired("variables")) { - auto idx = input_index(item.GetString(), inputs); + auto idx = find_input_index(item.GetString(), inputs); if ( inputs[idx].type() != Variable::VarType::real ) { throw std::runtime_error("Formulas only accept real-valued inputs, got type " + inputs[idx].typeStr() + " for variable " + inputs[idx].name()); @@ -341,7 +347,7 @@ double FormulaRef::evaluate(const std::vector& values) const { } Transform::Transform(const JSONObject& json, const Correction& context) { - variableIdx_ = input_index(json.getRequired("input"), context.inputs()); + variableIdx_ = find_input_index(json.getRequired("input"), context.inputs()); const auto& variable = context.inputs()[variableIdx_]; if ( variable.type() == Variable::VarType::string ) { throw std::runtime_error("Transform cannot rewrite string inputs"); @@ -372,7 +378,7 @@ HashPRNG::HashPRNG(const JSONObject& json, const Correction& context) variablesIdx_.reserve(inputs.Size()); for (const auto& input : inputs) { if ( ! input.IsString() ) { throw std::runtime_error("invalid hashprng input type"); } - size_t idx = input_index(input.GetString(), context.inputs()); + size_t idx = find_input_index(input.GetString(), context.inputs()); if ( context.inputs().at(idx).type() == Variable::VarType::string ) { throw std::runtime_error("HashPRNG cannot use string inputs as entropy sources"); } @@ -449,7 +455,10 @@ Binning::Binning(const JSONObject& json, const Correction& context) throw std::runtime_error ("Error when processing Binning: edges are neither an array nor a UniformBinning object"); } - variableIdx_ = input_index(json.getRequired("input"), context.inputs()); + variableIdx_ = find_input_index(json.getRequired("input"), context.inputs()); + if ( context.inputs().at(variableIdx_).type() == Variable::VarType::string ) { + throw std::runtime_error("Binning cannot use string inputs as binning variables"); + } Content default_value{0.}; const auto& flowbehavior = json.getRequiredValue("flow"); if ( flowbehavior == "clamp" ) { @@ -471,8 +480,7 @@ Binning::Binning(const JSONObject& json, const Correction& context) double Binning::evaluate(const std::vector& values) const { - double value = std::get(values[variableIdx_]); - std::size_t binIdx = find_bin_idx(value, bins_, flow_, variableIdx_, "Binning"); + std::size_t binIdx = find_bin_idx(values[variableIdx_], bins_, flow_, variableIdx_, "Binning"); const Content& child = contents_[binIdx]; return std::visit(node_evaluate{values}, child); } @@ -489,7 +497,11 @@ MultiBinning::MultiBinning(const JSONObject& json, const Correction& context) if ( dimension.IsArray() ) { // non-uniform binning std::vector dim_edges = parse_bin_edges(dimension.GetArray()); if ( ! input.IsString() ) { throw std::runtime_error("invalid multibinning input type"); } - axes_.push_back({input_index(input.GetString(), context.inputs()), 0, _NonUniformBins(std::move(dim_edges))}); + size_t variableIdx = find_input_index(input.GetString(), context.inputs()); + if ( context.inputs().at(variableIdx).type() == Variable::VarType::string ) { + throw std::runtime_error("MultiBinning cannot use string inputs as binning variables"); + } + axes_.push_back({variableIdx, 0, _NonUniformBins(std::move(dim_edges))}); } else if ( dimension.IsObject() ) { // UniformBinning const JSONObject uniformBins{dimension.GetObject()}; const auto n = uniformBins.getRequired("n"); @@ -499,7 +511,11 @@ MultiBinning::MultiBinning(const JSONObject& json, const Correction& context) } const auto low = uniformBins.getRequired("low"); const auto high = uniformBins.getRequired("high"); - axes_.push_back({input_index(input.GetString(), context.inputs()), 0, _UniformBins{n, low, high}}); + size_t variableIdx = find_input_index(input.GetString(), context.inputs()); + if ( context.inputs().at(variableIdx).type() == Variable::VarType::string ) { + throw std::runtime_error("MultiBinning cannot use string inputs as binning variables"); + } + axes_.push_back({variableIdx, 0, _UniformBins{n, low, high}}); } else { auto msg = "Error when processing MultiBinning: edges for dimension " + std::to_string(idx) + " are neither an array nor a UniformBinning object"; throw std::runtime_error (std::move(msg)); @@ -544,8 +560,7 @@ double MultiBinning::evaluate(const std::vector& values) const size_t dim {0}; for (const auto& [variableIdx, stride, edgesVariant] : axes_) { - double value = std::get(values[variableIdx]); - localidx = find_bin_idx(value, edgesVariant, flow_, variableIdx, "MultiBinning"); + localidx = find_bin_idx(values[variableIdx], edgesVariant, flow_, variableIdx, "MultiBinning"); if ( localidx == nbins(dim) ) // find_bin_idx is indicating we need to return the default value return std::visit(node_evaluate{values}, content_.back()); idx += localidx * stride; @@ -568,7 +583,7 @@ size_t MultiBinning::nbins(size_t dimension) const Category::Category(const JSONObject& json, const Correction& context) { - variableIdx_ = input_index(json.getRequired("input"), context.inputs()); + variableIdx_ = find_input_index(json.getRequired("input"), context.inputs()); const auto& variable = context.inputs()[variableIdx_]; if ( variable.type() == Variable::VarType::string ) { map_ = StrMap(); diff --git a/tests/test_issue217.py b/tests/test_issue217.py new file mode 100644 index 0000000..c1fdc22 --- /dev/null +++ b/tests/test_issue217.py @@ -0,0 +1,44 @@ +import pytest + +import correctionlib.schemav2 as cs + + +def test_issue217(): + content = [1.1, 1.08, 1.06, 1.04, 1.02, 1.0] + corr = cs.Correction( + name="NJetweight", + version=1, + inputs=[cs.Variable(name="nJets", type="int", description="Number of jets")], + output=cs.Variable( + name="weight", type="real", description="Multiplicative event weight" + ), + data=cs.Binning( + nodetype="binning", + input="nJets", + edges=[0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5], + content=content, + flow="clamp", + ), + ) + ceval = corr.to_evaluator() + assert [ceval.evaluate(i) for i in range(1, 7)] == content + + +def test_binning_invalidinput(): + corr = cs.Correction( + name="NJetweight", + version=1, + inputs=[cs.Variable(name="bogus", type="string")], + output=cs.Variable( + name="weight", type="real", description="Multiplicative event weight" + ), + data=cs.Binning( + nodetype="binning", + input="bogus", + edges=[0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5], + content=[1.1, 1.08, 1.06, 1.04, 1.02, 1.0], + flow="clamp", + ), + ) + with pytest.raises(RuntimeError): + corr.to_evaluator()