Skip to content

Commit

Permalink
refactor(local): rename ModelLoader->Provider, ref #222
Browse files Browse the repository at this point in the history
  • Loading branch information
iboB committed Jan 10, 2025
1 parent f03ce22 commit 41a95c0
Show file tree
Hide file tree
Showing 29 changed files with 279 additions and 279 deletions.
18 changes: 9 additions & 9 deletions cmake/ac_local_plugin_util.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,16 @@ void add_@nameSym@_to_ac_local_global_registry();
// Generated file. Do not edit!
#pragma once
#include "@aclpName@-plib.h"
#include <ac/local/ModelLoaderPtr.hpp>
#include <ac/local/ProviderPtr.hpp>
#include <vector>

namespace ac::local { class ModelLoaderRegistry; }
namespace ac::local { class ProviderRegistry; }

ACLPLIB_@nameSym@_API
const std::vector<ac::local::ModelLoaderPtr>& get_@nameSym@_model_loaders();
const std::vector<ac::local::ProviderPtr>& get_@nameSym@_model_providers();

ACLPLIB_@nameSym@_API
void add_@nameSym@_to_ac_local_registry(ac::local::ModelLoaderRegistry& registry);
void add_@nameSym@_to_ac_local_registry(ac::local::ProviderRegistry& registry);
]=]
@ONLY
)
Expand All @@ -113,15 +113,15 @@ PlibHelper g_helper{ac::@nameSym@::getPluginInterface()};

extern "C"
void add_@nameSym@_to_ac_local_global_registry() {
g_helper.addLoadersToGlobalRegistry();
g_helper.addProvidersToGlobalRegistry();
}

void add_@nameSym@_to_ac_local_registry(ac::local::ModelLoaderRegistry& registry) {
g_helper.addLoadersToRegistry(registry);
void add_@nameSym@_to_ac_local_registry(ac::local::ProviderRegistry& registry) {
g_helper.addProvidersToRegistry(registry);
}

const std::vector<ac::local::ModelLoaderPtr>& get_@nameSym@_model_loaders() {
return g_helper.getLoaders();
const std::vector<ac::local::ProviderPtr>& get_@nameSym@_model_providers() {
return g_helper.getProviders();
}
]=]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

namespace ac::local::schema {

struct DummyLoader {
struct DummyProvider {
static inline constexpr std::string_view id = "dummy";
static inline constexpr std::string_view description = "Dummy inference for tests, examples, and experiments.";

Expand Down
18 changes: 9 additions & 9 deletions dummy-plugin/code/ac/dummy/LocalDummy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// SPDX-License-Identifier: MIT
//
#include "LocalDummy.hpp"
#include "DummyLoaderSchema.hpp"
#include "DummyProviderSchema.hpp"

#include "Instance.hpp"
#include "Model.hpp"
Expand All @@ -11,7 +11,7 @@

#include <ac/local/Instance.hpp>
#include <ac/local/Model.hpp>
#include <ac/local/ModelLoader.hpp>
#include <ac/local/Provider.hpp>

#include <ac/schema/DispatchHelpers.hpp>

Expand All @@ -29,7 +29,7 @@ class DummyInstance final : public Instance {
dummy::Instance m_instance;
schema::OpDispatcherData m_dispatcherData;
public:
using Schema = schema::DummyLoader::InstanceGeneral;
using Schema = schema::DummyProvider::InstanceGeneral;

static dummy::Instance::InitParams InitParams_fromDict(Dict&& d) {
auto schemaParams = schema::Struct_fromDict<Schema::Params>(astl::move(d));
Expand Down Expand Up @@ -78,7 +78,7 @@ class DummyInstance final : public Instance {
class DummyModel final : public Model {
std::shared_ptr<dummy::Model> m_model;
public:
using Schema = schema::DummyLoader;
using Schema = schema::DummyProvider;

static dummy::Model::Params ModelParams_fromDict(Dict& d) {
auto schemaParams = schema::Struct_fromDict<Schema::Params>(std::move(d));
Expand All @@ -103,7 +103,7 @@ class DummyModel final : public Model {
}
};

class DummyModelLoader final : public ModelLoader {
class DummyProvider final : public Provider {
public:
virtual const Info& info() const noexcept override {
static Info i = {
Expand Down Expand Up @@ -147,9 +147,9 @@ class DummyModelLoader final : public ModelLoader {

namespace ac::dummy {

std::vector<ac::local::ModelLoaderPtr> getLoaders() {
std::vector<ac::local::ModelLoaderPtr> ret;
ret.push_back(std::make_unique<local::DummyModelLoader>());
std::vector<ac::local::ProviderPtr> getProviders() {
std::vector<ac::local::ProviderPtr> ret;
ret.push_back(std::make_unique<local::DummyProvider>());
return ret;
}

Expand All @@ -162,7 +162,7 @@ local::PluginInterface getPluginInterface() {
ACLP_dummy_VERSION_MAJOR, ACLP_dummy_VERSION_MINOR, ACLP_dummy_VERSION_PATCH
},
.init = nullptr,
.getLoaders = getLoaders,
.getProviders = getProviders,
};
}

Expand Down
6 changes: 3 additions & 3 deletions dummy-plugin/example/e-gen-dummy-schema.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// Copyright (c) Alpaca Core
// SPDX-License-Identifier: MIT
//
#include <ac/dummy/DummyLoaderSchema.hpp>
#include <ac/schema/GenerateLoaderSchemaDict.hpp>
#include <ac/dummy/DummyProviderSchema.hpp>
#include <ac/schema/GenerateProviderSchemaDict.hpp>
#include <iostream>

int main() {
auto d = ac::local::schema::generateLoaderSchema<acnl::ordered_json, ac::local::schema::DummyLoader>();
auto d = ac::local::schema::generateProviderSchema<acnl::ordered_json, ac::local::schema::DummyProvider>();
std::cout << d.dump(2) << std::endl;
return 0;
}
4 changes: 2 additions & 2 deletions dummy-plugin/test/t-dummy-plib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
// SPDX-License-Identifier: MIT
//
#include <aclp-dummy-plib.hpp>
#include <ac/local/ModelLoaderRegistry.hpp>
#include <ac/local/ProviderRegistry.hpp>

#include <ac-test-util/JalogFixture.inl>

struct DummyRegistry : public ac::local::ModelLoaderRegistry {
struct DummyRegistry : public ac::local::ProviderRegistry {
DummyRegistry() {
add_dummy_to_ac_local_registry(*this);
}
Expand Down
10 changes: 5 additions & 5 deletions dummy-plugin/test/t-dummy-plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <aclp-dummy-info.h>
#include <ac/local/Lib.hpp>
#include <ac/local/PluginManager.hpp>
#include <ac/local/ModelLoaderRegistry.hpp>
#include <ac/local/ProviderRegistry.hpp>
#include <doctest/doctest.h>

#include <ac-test-util/JalogFixture.inl>
Expand All @@ -16,11 +16,11 @@ struct GlobalFixture {
};
GlobalFixture globalFixture;

struct DummyRegistry : public ac::local::ModelLoaderRegistry {
struct DummyRegistry : public ac::local::ProviderRegistry {
DummyRegistry() {
auto& loaders = ac::local::Lib::modelLoaderRegistry().loaders();
for (auto& loader : loaders) {
addLoader(*loader.loader);
auto& providers = ac::local::Lib::providerRegistry().providers();
for (auto& provider : providers) {
addProvider(*provider.provider);
}
}
};
Expand Down
10 changes: 5 additions & 5 deletions dummy-plugin/test/t-dummy-schema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
#include <ac/local/Lib.hpp>
#include <ac/local/Model.hpp>
#include <ac/local/Instance.hpp>
#include <ac/local/ModelLoaderRegistry.hpp>
#include <ac/local/ProviderRegistry.hpp>

#include <ac/schema/CallHelpers.hpp>

#include <ac/local/PluginPlibUtil.inl>

#include <ac/dummy/LocalDummy.hpp>
#include <ac/dummy/DummyLoaderSchema.hpp>
#include <ac/dummy/DummyProviderSchema.hpp>

#include <ac-test-util/JalogFixture.inl>

Expand All @@ -24,7 +24,7 @@ struct LoadDummyFixture {
LoadDummyFixture()
: helper(ac::dummy::getPluginInterface())
{
helper.addLoadersToGlobalRegistry();
helper.addProvidersToGlobalRegistry();
}
};

Expand All @@ -34,11 +34,11 @@ TEST_CASE("dummy schema") {
auto model = ac::local::Lib::loadModel({
.type = "dummy",
.name = "synthetic"
}, {});
}, {});

REQUIRE(!!model);

using Instance = ac::local::schema::DummyLoader::InstanceGeneral;
using Instance = ac::local::schema::DummyProvider::InstanceGeneral;
auto instance = Model_createInstance<Instance>(*model, {.cutoff = 2});

using Interface = ac::local::schema::DummyInterface;
Expand Down
14 changes: 7 additions & 7 deletions local/code/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ target_sources(ac-local
ac/local/Model.hpp
ac/local/Instance.hpp

ac/local/ModelLoader.hpp
ac/local/ModelLoaderPtr.hpp
ac/local/ModelLoaderRegistry.hpp
ac/local/Provider.hpp
ac/local/ProviderPtr.hpp
ac/local/ProviderRegistry.hpp

ac/local/PluginInterface.hpp
ac/local/PluginInfo.hpp
Expand All @@ -45,22 +45,22 @@ target_sources(ac-local

ac/local/Lib.hpp

ac/local/ModelLoaderScorer.hpp
ac/local/CommonModelLoaderScorers.hpp
ac/local/ProviderScorer.hpp
ac/local/CommonProviderScorers.hpp
PRIVATE
ac/local/Logging.hpp
ac/local/Logging.cpp

ac/local/VtableExports.cpp

ac/local/ModelLoaderRegistry.cpp
ac/local/ProviderRegistry.cpp
ac/local/PluginManager.cpp

ac/local/Session.cpp

ac/local/Lib.cpp

ac/local/CommonModelLoaderScorers.cpp
ac/local/CommonProviderScorers.cpp
)

install(TARGETS ac-local
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
// Copyright (c) Alpaca Core
// SPDX-License-Identifier: MIT
//
#include "CommonModelLoaderScorers.hpp"
#include "ModelLoader.hpp"
#include "CommonProviderScorers.hpp"
#include "Provider.hpp"

namespace ac::local {

ModelLoaderScorer::score_t CanLoadScorer::score(
const ModelLoader& loader,
ProviderScorer::score_t CanLoadScorer::score(
const Provider& provider,
const PluginInfo*,
const ModelAssetDesc& model,
const Dict& params
) const noexcept {
return loader.canLoadModel(model, params);
return provider.canLoadModel(model, params);
}

} // namespace ac::local
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
// SPDX-License-Identifier: MIT
//
#pragma once
#include "ModelLoaderScorer.hpp"
#include "ProviderScorer.hpp"

namespace ac::local {

struct AC_LOCAL_EXPORT BooleanScorer : public ModelLoaderScorer {
struct AC_LOCAL_EXPORT BooleanScorer : public ProviderScorer {
score_t denyScore() const noexcept override final { return 0; }
score_t acceptScore() const noexcept override final { return 1; }
};

struct AC_LOCAL_EXPORT CanLoadScorer final : public BooleanScorer {
score_t score(
const ModelLoader& loader,
const Provider& provider,
const PluginInfo* info,
const ModelAssetDesc& model,
const Dict& params
Expand Down
20 changes: 10 additions & 10 deletions local/code/ac/local/Lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// SPDX-License-Identifier: MIT
//
#include "Lib.hpp"
#include "ModelLoaderRegistry.hpp"
#include "ProviderRegistry.hpp"
#include "PluginManager.hpp"
#include "ModelAssetDesc.hpp"
#include <ac/Dict.hpp>
Expand All @@ -11,24 +11,24 @@
namespace ac::local {

namespace {
ModelLoaderRegistry g_modelLoaderRegistry("global");
PluginManager g_pluginManager(g_modelLoaderRegistry);
ProviderRegistry g_providerRegistry("global");
PluginManager g_pluginManager(g_providerRegistry);
} // namespace

ModelLoaderRegistry& Lib::modelLoaderRegistry() {
return g_modelLoaderRegistry;
ProviderRegistry& Lib::providerRegistry() {
return g_providerRegistry;
}

void Lib::addLoader(ModelLoader& loader) {
g_modelLoaderRegistry.addLoader(loader);
void Lib::addProvider(Provider& provider) {
g_providerRegistry.addProvider(provider);
}

ModelPtr Lib::loadModel(ModelAssetDesc desc, Dict params, ProgressCb cb) {
return g_modelLoaderRegistry.loadModel(astl::move(desc), astl::move(params), astl::move(cb));
return g_providerRegistry.loadModel(astl::move(desc), astl::move(params), astl::move(cb));
}

ModelPtr Lib::loadModel(const ModelLoaderScorer& scorer, ModelAssetDesc desc, Dict params, ProgressCb cb) {
return g_modelLoaderRegistry.loadModel(scorer, astl::move(desc), astl::move(params), astl::move(cb));
ModelPtr Lib::loadModel(const ProviderScorer& scorer, ModelAssetDesc desc, Dict params, ProgressCb cb) {
return g_providerRegistry.loadModel(scorer, astl::move(desc), astl::move(params), astl::move(cb));
}

PluginManager& Lib::pluginManager() {
Expand Down
12 changes: 6 additions & 6 deletions local/code/ac/local/Lib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@
#include <string_view>

namespace ac::local {
class ModelLoaderRegistry;
class ModelLoader;
class ModelLoaderScorer;
class ProviderRegistry;
class Provider;
class ProviderScorer;
struct ModelAssetDesc;

class PluginManager;
struct PluginInfo;

struct AC_LOCAL_EXPORT Lib {
static ModelLoaderRegistry& modelLoaderRegistry();
static ProviderRegistry& providerRegistry();

static void addLoader(ModelLoader& loader);
static void addProvider(Provider& provider);

static ModelPtr loadModel(ModelAssetDesc desc, Dict params, ProgressCb cb = {});
static ModelPtr loadModel(const ModelLoaderScorer& scorer, ModelAssetDesc desc, Dict params, ProgressCb cb = {});
static ModelPtr loadModel(const ProviderScorer& scorer, ModelAssetDesc desc, Dict params, ProgressCb cb = {});

static PluginManager& pluginManager();

Expand Down
2 changes: 1 addition & 1 deletion local/code/ac/local/ModelAssetDesc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace ac::local {

/// Model asset description. Used by `ModelLoader` to load models.
/// Model asset description. Used by `Provider` to load models.
/// @ingroup cpp-local
struct ModelAssetDesc {
/// Asset (weights) type. May be used by loaders to check whether they can load the model.
Expand Down
Loading

0 comments on commit 41a95c0

Please sign in to comment.