From 41a95c0d7d1dec074d55ae87ab8ca4af4fe95f93 Mon Sep 17 00:00:00 2001 From: Borislav Stanimirov Date: Fri, 10 Jan 2025 09:04:57 +0200 Subject: [PATCH] refactor(local): rename ModelLoader->Provider, ref #222 --- cmake/ac_local_plugin_util.cmake | 18 ++-- ...aderSchema.hpp => DummyProviderSchema.hpp} | 2 +- dummy-plugin/code/ac/dummy/LocalDummy.cpp | 18 ++-- dummy-plugin/example/e-gen-dummy-schema.cpp | 6 +- dummy-plugin/test/t-dummy-plib.cpp | 4 +- dummy-plugin/test/t-dummy-plugin.cpp | 10 +-- dummy-plugin/test/t-dummy-schema.cpp | 10 +-- local/code/CMakeLists.txt | 14 ++-- ...rScorers.cpp => CommonProviderScorers.cpp} | 10 +-- ...rScorers.hpp => CommonProviderScorers.hpp} | 6 +- local/code/ac/local/Lib.cpp | 20 ++--- local/code/ac/local/Lib.hpp | 12 +-- local/code/ac/local/ModelAssetDesc.hpp | 2 +- local/code/ac/local/ModelLoaderRegistry.cpp | 83 ------------------- local/code/ac/local/ModelLoaderRegistry.hpp | 55 ------------ local/code/ac/local/PluginInfo.hpp | 4 +- local/code/ac/local/PluginInterface.hpp | 6 +- local/code/ac/local/PluginManager.cpp | 18 ++-- local/code/ac/local/PluginManager.hpp | 8 +- local/code/ac/local/PluginPlibUtil.inl | 22 ++--- .../local/{ModelLoader.hpp => Provider.hpp} | 30 +++---- .../{ModelLoaderPtr.hpp => ProviderPtr.hpp} | 4 +- local/code/ac/local/ProviderRegistry.cpp | 83 +++++++++++++++++++ local/code/ac/local/ProviderRegistry.hpp | 55 ++++++++++++ ...delLoaderScorer.hpp => ProviderScorer.hpp} | 12 +-- local/code/ac/local/VtableExports.cpp | 8 +- local/test/CMakeLists.txt | 2 +- ...derRegistry.cpp => t-ProviderRegistry.cpp} | 34 ++++---- ...ict.hpp => GenerateProviderSchemaDict.hpp} | 2 +- 29 files changed, 279 insertions(+), 279 deletions(-) rename dummy-plugin/code/ac/dummy/{DummyLoaderSchema.hpp => DummyProviderSchema.hpp} (98%) rename local/code/ac/local/{CommonModelLoaderScorers.cpp => CommonProviderScorers.cpp} (52%) rename local/code/ac/local/{CommonModelLoaderScorers.hpp => CommonProviderScorers.hpp} (78%) delete mode 100644 local/code/ac/local/ModelLoaderRegistry.cpp delete mode 100644 local/code/ac/local/ModelLoaderRegistry.hpp rename local/code/ac/local/{ModelLoader.hpp => Provider.hpp} (57%) rename local/code/ac/local/{ModelLoaderPtr.hpp => ProviderPtr.hpp} (64%) create mode 100644 local/code/ac/local/ProviderRegistry.cpp create mode 100644 local/code/ac/local/ProviderRegistry.hpp rename local/code/ac/local/{ModelLoaderScorer.hpp => ProviderScorer.hpp} (73%) rename local/test/{t-ModelLoaderRegistry.cpp => t-ProviderRegistry.cpp} (50%) rename schema/code/ac/schema/{GenerateLoaderSchemaDict.hpp => GenerateProviderSchemaDict.hpp} (97%) diff --git a/cmake/ac_local_plugin_util.cmake b/cmake/ac_local_plugin_util.cmake index 6dbc02bc..550ccce8 100644 --- a/cmake/ac_local_plugin_util.cmake +++ b/cmake/ac_local_plugin_util.cmake @@ -88,16 +88,16 @@ void add_@nameSym@_to_ac_local_global_registry(); // Generated file. Do not edit! #pragma once #include "@aclpName@-plib.h" -#include +#include #include -namespace ac::local { class ModelLoaderRegistry; } +namespace ac::local { class ProviderRegistry; } ACLPLIB_@nameSym@_API -const std::vector& get_@nameSym@_model_loaders(); +const std::vector& get_@nameSym@_model_providers(); ACLPLIB_@nameSym@_API -void add_@nameSym@_to_ac_local_registry(ac::local::ModelLoaderRegistry& registry); +void add_@nameSym@_to_ac_local_registry(ac::local::ProviderRegistry& registry); ]=] @ONLY ) @@ -113,15 +113,15 @@ PlibHelper g_helper{ac::@nameSym@::getPluginInterface()}; extern "C" void add_@nameSym@_to_ac_local_global_registry() { - g_helper.addLoadersToGlobalRegistry(); + g_helper.addProvidersToGlobalRegistry(); } -void add_@nameSym@_to_ac_local_registry(ac::local::ModelLoaderRegistry& registry) { - g_helper.addLoadersToRegistry(registry); +void add_@nameSym@_to_ac_local_registry(ac::local::ProviderRegistry& registry) { + g_helper.addProvidersToRegistry(registry); } -const std::vector& get_@nameSym@_model_loaders() { - return g_helper.getLoaders(); +const std::vector& get_@nameSym@_model_providers() { + return g_helper.getProviders(); } ]=] ) diff --git a/dummy-plugin/code/ac/dummy/DummyLoaderSchema.hpp b/dummy-plugin/code/ac/dummy/DummyProviderSchema.hpp similarity index 98% rename from dummy-plugin/code/ac/dummy/DummyLoaderSchema.hpp rename to dummy-plugin/code/ac/dummy/DummyProviderSchema.hpp index ca13282d..8ab02124 100644 --- a/dummy-plugin/code/ac/dummy/DummyLoaderSchema.hpp +++ b/dummy-plugin/code/ac/dummy/DummyProviderSchema.hpp @@ -7,7 +7,7 @@ namespace ac::local::schema { -struct DummyLoader { +struct DummyProvider { static inline constexpr std::string_view id = "dummy"; static inline constexpr std::string_view description = "Dummy inference for tests, examples, and experiments."; diff --git a/dummy-plugin/code/ac/dummy/LocalDummy.cpp b/dummy-plugin/code/ac/dummy/LocalDummy.cpp index 772039a9..8b50121d 100644 --- a/dummy-plugin/code/ac/dummy/LocalDummy.cpp +++ b/dummy-plugin/code/ac/dummy/LocalDummy.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT // #include "LocalDummy.hpp" -#include "DummyLoaderSchema.hpp" +#include "DummyProviderSchema.hpp" #include "Instance.hpp" #include "Model.hpp" @@ -11,7 +11,7 @@ #include #include -#include +#include #include @@ -29,7 +29,7 @@ class DummyInstance final : public Instance { dummy::Instance m_instance; schema::OpDispatcherData m_dispatcherData; public: - using Schema = schema::DummyLoader::InstanceGeneral; + using Schema = schema::DummyProvider::InstanceGeneral; static dummy::Instance::InitParams InitParams_fromDict(Dict&& d) { auto schemaParams = schema::Struct_fromDict(astl::move(d)); @@ -78,7 +78,7 @@ class DummyInstance final : public Instance { class DummyModel final : public Model { std::shared_ptr m_model; public: - using Schema = schema::DummyLoader; + using Schema = schema::DummyProvider; static dummy::Model::Params ModelParams_fromDict(Dict& d) { auto schemaParams = schema::Struct_fromDict(std::move(d)); @@ -103,7 +103,7 @@ class DummyModel final : public Model { } }; -class DummyModelLoader final : public ModelLoader { +class DummyProvider final : public Provider { public: virtual const Info& info() const noexcept override { static Info i = { @@ -147,9 +147,9 @@ class DummyModelLoader final : public ModelLoader { namespace ac::dummy { -std::vector getLoaders() { - std::vector ret; - ret.push_back(std::make_unique()); +std::vector getProviders() { + std::vector ret; + ret.push_back(std::make_unique()); return ret; } @@ -162,7 +162,7 @@ local::PluginInterface getPluginInterface() { ACLP_dummy_VERSION_MAJOR, ACLP_dummy_VERSION_MINOR, ACLP_dummy_VERSION_PATCH }, .init = nullptr, - .getLoaders = getLoaders, + .getProviders = getProviders, }; } diff --git a/dummy-plugin/example/e-gen-dummy-schema.cpp b/dummy-plugin/example/e-gen-dummy-schema.cpp index 4558fc74..a376cb39 100644 --- a/dummy-plugin/example/e-gen-dummy-schema.cpp +++ b/dummy-plugin/example/e-gen-dummy-schema.cpp @@ -1,12 +1,12 @@ // Copyright (c) Alpaca Core // SPDX-License-Identifier: MIT // -#include -#include +#include +#include #include int main() { - auto d = ac::local::schema::generateLoaderSchema(); + auto d = ac::local::schema::generateProviderSchema(); std::cout << d.dump(2) << std::endl; return 0; } diff --git a/dummy-plugin/test/t-dummy-plib.cpp b/dummy-plugin/test/t-dummy-plib.cpp index 7011ddd1..9d7b204c 100644 --- a/dummy-plugin/test/t-dummy-plib.cpp +++ b/dummy-plugin/test/t-dummy-plib.cpp @@ -2,11 +2,11 @@ // SPDX-License-Identifier: MIT // #include -#include +#include #include -struct DummyRegistry : public ac::local::ModelLoaderRegistry { +struct DummyRegistry : public ac::local::ProviderRegistry { DummyRegistry() { add_dummy_to_ac_local_registry(*this); } diff --git a/dummy-plugin/test/t-dummy-plugin.cpp b/dummy-plugin/test/t-dummy-plugin.cpp index d4efb74d..bc6937c2 100644 --- a/dummy-plugin/test/t-dummy-plugin.cpp +++ b/dummy-plugin/test/t-dummy-plugin.cpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include @@ -16,11 +16,11 @@ struct GlobalFixture { }; GlobalFixture globalFixture; -struct DummyRegistry : public ac::local::ModelLoaderRegistry { +struct DummyRegistry : public ac::local::ProviderRegistry { DummyRegistry() { - auto& loaders = ac::local::Lib::modelLoaderRegistry().loaders(); - for (auto& loader : loaders) { - addLoader(*loader.loader); + auto& providers = ac::local::Lib::providerRegistry().providers(); + for (auto& provider : providers) { + addProvider(*provider.provider); } } }; diff --git a/dummy-plugin/test/t-dummy-schema.cpp b/dummy-plugin/test/t-dummy-schema.cpp index 8da64252..557cd1ab 100644 --- a/dummy-plugin/test/t-dummy-schema.cpp +++ b/dummy-plugin/test/t-dummy-schema.cpp @@ -4,14 +4,14 @@ #include #include #include -#include +#include #include #include #include -#include +#include #include @@ -24,7 +24,7 @@ struct LoadDummyFixture { LoadDummyFixture() : helper(ac::dummy::getPluginInterface()) { - helper.addLoadersToGlobalRegistry(); + helper.addProvidersToGlobalRegistry(); } }; @@ -34,11 +34,11 @@ TEST_CASE("dummy schema") { auto model = ac::local::Lib::loadModel({ .type = "dummy", .name = "synthetic" - }, {}); + }, {}); REQUIRE(!!model); - using Instance = ac::local::schema::DummyLoader::InstanceGeneral; + using Instance = ac::local::schema::DummyProvider::InstanceGeneral; auto instance = Model_createInstance(*model, {.cutoff = 2}); using Interface = ac::local::schema::DummyInterface; diff --git a/local/code/CMakeLists.txt b/local/code/CMakeLists.txt index 477e01e5..76ce3b0a 100644 --- a/local/code/CMakeLists.txt +++ b/local/code/CMakeLists.txt @@ -32,9 +32,9 @@ target_sources(ac-local ac/local/Model.hpp ac/local/Instance.hpp - ac/local/ModelLoader.hpp - ac/local/ModelLoaderPtr.hpp - ac/local/ModelLoaderRegistry.hpp + ac/local/Provider.hpp + ac/local/ProviderPtr.hpp + ac/local/ProviderRegistry.hpp ac/local/PluginInterface.hpp ac/local/PluginInfo.hpp @@ -45,22 +45,22 @@ target_sources(ac-local ac/local/Lib.hpp - ac/local/ModelLoaderScorer.hpp - ac/local/CommonModelLoaderScorers.hpp + ac/local/ProviderScorer.hpp + ac/local/CommonProviderScorers.hpp PRIVATE ac/local/Logging.hpp ac/local/Logging.cpp ac/local/VtableExports.cpp - ac/local/ModelLoaderRegistry.cpp + ac/local/ProviderRegistry.cpp ac/local/PluginManager.cpp ac/local/Session.cpp ac/local/Lib.cpp - ac/local/CommonModelLoaderScorers.cpp + ac/local/CommonProviderScorers.cpp ) install(TARGETS ac-local diff --git a/local/code/ac/local/CommonModelLoaderScorers.cpp b/local/code/ac/local/CommonProviderScorers.cpp similarity index 52% rename from local/code/ac/local/CommonModelLoaderScorers.cpp rename to local/code/ac/local/CommonProviderScorers.cpp index 9b25d4d2..82cc7ebb 100644 --- a/local/code/ac/local/CommonModelLoaderScorers.cpp +++ b/local/code/ac/local/CommonProviderScorers.cpp @@ -1,18 +1,18 @@ // Copyright (c) Alpaca Core // SPDX-License-Identifier: MIT // -#include "CommonModelLoaderScorers.hpp" -#include "ModelLoader.hpp" +#include "CommonProviderScorers.hpp" +#include "Provider.hpp" namespace ac::local { -ModelLoaderScorer::score_t CanLoadScorer::score( - const ModelLoader& loader, +ProviderScorer::score_t CanLoadScorer::score( + const Provider& provider, const PluginInfo*, const ModelAssetDesc& model, const Dict& params ) const noexcept { - return loader.canLoadModel(model, params); + return provider.canLoadModel(model, params); } } // namespace ac::local diff --git a/local/code/ac/local/CommonModelLoaderScorers.hpp b/local/code/ac/local/CommonProviderScorers.hpp similarity index 78% rename from local/code/ac/local/CommonModelLoaderScorers.hpp rename to local/code/ac/local/CommonProviderScorers.hpp index df22d26e..58321f88 100644 --- a/local/code/ac/local/CommonModelLoaderScorers.hpp +++ b/local/code/ac/local/CommonProviderScorers.hpp @@ -2,18 +2,18 @@ // SPDX-License-Identifier: MIT // #pragma once -#include "ModelLoaderScorer.hpp" +#include "ProviderScorer.hpp" namespace ac::local { -struct AC_LOCAL_EXPORT BooleanScorer : public ModelLoaderScorer { +struct AC_LOCAL_EXPORT BooleanScorer : public ProviderScorer { score_t denyScore() const noexcept override final { return 0; } score_t acceptScore() const noexcept override final { return 1; } }; struct AC_LOCAL_EXPORT CanLoadScorer final : public BooleanScorer { score_t score( - const ModelLoader& loader, + const Provider& provider, const PluginInfo* info, const ModelAssetDesc& model, const Dict& params diff --git a/local/code/ac/local/Lib.cpp b/local/code/ac/local/Lib.cpp index e7f42c14..f701b2fb 100644 --- a/local/code/ac/local/Lib.cpp +++ b/local/code/ac/local/Lib.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT // #include "Lib.hpp" -#include "ModelLoaderRegistry.hpp" +#include "ProviderRegistry.hpp" #include "PluginManager.hpp" #include "ModelAssetDesc.hpp" #include @@ -11,24 +11,24 @@ namespace ac::local { namespace { -ModelLoaderRegistry g_modelLoaderRegistry("global"); -PluginManager g_pluginManager(g_modelLoaderRegistry); +ProviderRegistry g_providerRegistry("global"); +PluginManager g_pluginManager(g_providerRegistry); } // namespace -ModelLoaderRegistry& Lib::modelLoaderRegistry() { - return g_modelLoaderRegistry; +ProviderRegistry& Lib::providerRegistry() { + return g_providerRegistry; } -void Lib::addLoader(ModelLoader& loader) { - g_modelLoaderRegistry.addLoader(loader); +void Lib::addProvider(Provider& provider) { + g_providerRegistry.addProvider(provider); } ModelPtr Lib::loadModel(ModelAssetDesc desc, Dict params, ProgressCb cb) { - return g_modelLoaderRegistry.loadModel(astl::move(desc), astl::move(params), astl::move(cb)); + return g_providerRegistry.loadModel(astl::move(desc), astl::move(params), astl::move(cb)); } -ModelPtr Lib::loadModel(const ModelLoaderScorer& scorer, ModelAssetDesc desc, Dict params, ProgressCb cb) { - return g_modelLoaderRegistry.loadModel(scorer, astl::move(desc), astl::move(params), astl::move(cb)); +ModelPtr Lib::loadModel(const ProviderScorer& scorer, ModelAssetDesc desc, Dict params, ProgressCb cb) { + return g_providerRegistry.loadModel(scorer, astl::move(desc), astl::move(params), astl::move(cb)); } PluginManager& Lib::pluginManager() { diff --git a/local/code/ac/local/Lib.hpp b/local/code/ac/local/Lib.hpp index ee9c2c60..a491b952 100644 --- a/local/code/ac/local/Lib.hpp +++ b/local/code/ac/local/Lib.hpp @@ -12,21 +12,21 @@ #include namespace ac::local { -class ModelLoaderRegistry; -class ModelLoader; -class ModelLoaderScorer; +class ProviderRegistry; +class Provider; +class ProviderScorer; struct ModelAssetDesc; class PluginManager; struct PluginInfo; struct AC_LOCAL_EXPORT Lib { - static ModelLoaderRegistry& modelLoaderRegistry(); + static ProviderRegistry& providerRegistry(); - static void addLoader(ModelLoader& loader); + static void addProvider(Provider& provider); static ModelPtr loadModel(ModelAssetDesc desc, Dict params, ProgressCb cb = {}); - static ModelPtr loadModel(const ModelLoaderScorer& scorer, ModelAssetDesc desc, Dict params, ProgressCb cb = {}); + static ModelPtr loadModel(const ProviderScorer& scorer, ModelAssetDesc desc, Dict params, ProgressCb cb = {}); static PluginManager& pluginManager(); diff --git a/local/code/ac/local/ModelAssetDesc.hpp b/local/code/ac/local/ModelAssetDesc.hpp index 6f9fe94e..8547f6cb 100644 --- a/local/code/ac/local/ModelAssetDesc.hpp +++ b/local/code/ac/local/ModelAssetDesc.hpp @@ -8,7 +8,7 @@ namespace ac::local { -/// Model asset description. Used by `ModelLoader` to load models. +/// Model asset description. Used by `Provider` to load models. /// @ingroup cpp-local struct ModelAssetDesc { /// Asset (weights) type. May be used by loaders to check whether they can load the model. diff --git a/local/code/ac/local/ModelLoaderRegistry.cpp b/local/code/ac/local/ModelLoaderRegistry.cpp deleted file mode 100644 index c0489072..00000000 --- a/local/code/ac/local/ModelLoaderRegistry.cpp +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Alpaca Core -// SPDX-License-Identifier: MIT -// -#include "ModelLoaderRegistry.hpp" -#include "ModelLoader.hpp" -#include "CommonModelLoaderScorers.hpp" -#include "Logging.hpp" -#include -#include -#include -#include -#include - -namespace ac::local { - -ModelLoaderRegistry::ModelLoaderRegistry(std::string_view name) - : m_name(name) -{ - if (m_name.empty()) { - char hex[20] = "0x"; - auto r = std::to_chars(hex + 2, hex + sizeof(hex), reinterpret_cast(this), 16); - m_name = std::string_view(hex, r.ptr - hex); - } -} - -inline jalog::BasicStream& operator,(jalog::BasicStream& s, const std::vector& vec) { - s, "["; - for (const auto& v : vec) { - s, v, ", "; - } - s, "]"; - return s; -} - -void ModelLoaderRegistry::addLoader(ModelLoader& loader, PluginInfo* plugin) { - [[maybe_unused]] auto& info = loader.info(); - AC_LOCAL_LOG(Info, "Registry ", m_name, " adding loader ", info.name, - "\n vendor: ", info.vendor, - "\n tags: ", info.tags - ); - - m_loaders.push_back({&loader, plugin}); -} - -void ModelLoaderRegistry::removeLoader(ModelLoader& loader) { - astl::erase_first_if(m_loaders, [&](const auto& data) { return data.loader == &loader; }); -} - -ModelLoader* ModelLoaderRegistry::findBestLoader( - const ModelLoaderScorer& scorer, const ModelAssetDesc& desc, const Dict& params -) const { - ModelLoader* best = nullptr; - auto bestScore = scorer.denyScore(); - auto acceptScore = scorer.acceptScore(); - - for (const auto& data : m_loaders) { - auto score = scorer.score(*data.loader, data.plugin, desc, params); - if (score > bestScore) { - best = data.loader; - bestScore = score; - } - if (score >= acceptScore) { - return best; - } - } - - return best; -} - -ModelPtr ModelLoaderRegistry::loadModel(const ModelLoaderScorer& scorer, ModelAssetDesc desc, Dict params, ProgressCb cb) const { - if (auto loader = findBestLoader(scorer, desc, params)) { - return loader->loadModel(astl::move(desc), astl::move(params), astl::move(cb)); - } - - ac::throw_ex{} << "No loader found for: " << desc.name; - MSVC_WO_10766806(); -} - -ModelPtr ModelLoaderRegistry::loadModel(ModelAssetDesc desc, Dict params, ProgressCb cb) const { - return loadModel(CanLoadScorer{}, astl::move(desc), astl::move(params), astl::move(cb)); -} - -} // namespace ac::local diff --git a/local/code/ac/local/ModelLoaderRegistry.hpp b/local/code/ac/local/ModelLoaderRegistry.hpp deleted file mode 100644 index 499c71ba..00000000 --- a/local/code/ac/local/ModelLoaderRegistry.hpp +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) Alpaca Core -// SPDX-License-Identifier: MIT -// -#pragma once -#include "export.h" -#include "ModelPtr.hpp" -#include "ModelAssetDesc.hpp" -#include "ProgressCb.hpp" -#include -#include -#include - -/// @defgroup cpp-local C++ Local API -/// C++ API for local inference. - -namespace ac::local { -class ModelLoader; -struct PluginInfo; -class ModelLoaderScorer; - -class AC_LOCAL_EXPORT ModelLoaderRegistry { -public: - ModelLoaderRegistry(std::string_view name = {}); - ModelLoaderRegistry(const ModelLoaderRegistry&) = delete; - ModelLoaderRegistry& operator=(const ModelLoaderRegistry&) = delete; - - const std::string& name() const noexcept { return m_name; } - - struct LoaderData { - ModelLoader* loader; // never null - PluginInfo* plugin; // may be null for loaders that have been added directly - }; - - const std::vector& loaders() const noexcept { return m_loaders; } - - void addLoader(ModelLoader& loader, PluginInfo* plugin = nullptr); - void removeLoader(ModelLoader& loader); - - // find the best loader for the given model description and parameters - // returns nullptr if all loaders rank equal or lower then the denyScore of the scorer - ModelLoader* findBestLoader(const ModelLoaderScorer& scorer, const ModelAssetDesc& desc, const Dict& params) const; - - // utliity functions to directly load the model - - // load model with the first loader which can load it - ModelPtr loadModel(ModelAssetDesc desc, Dict params, ProgressCb cb = {}) const; - - // load model with a scorer to select the best loader - ModelPtr loadModel(const ModelLoaderScorer& scorer, ModelAssetDesc desc, Dict params, ProgressCb cb = {}) const; -private: - std::string m_name; - std::vector m_loaders; -}; - -} // namespace ac::local diff --git a/local/code/ac/local/PluginInfo.hpp b/local/code/ac/local/PluginInfo.hpp index 1769f527..5d8e374f 100644 --- a/local/code/ac/local/PluginInfo.hpp +++ b/local/code/ac/local/PluginInfo.hpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT // #pragma once -#include "ModelLoaderPtr.hpp" +#include "ProviderPtr.hpp" #include #include @@ -26,7 +26,7 @@ struct PluginInfo { // plugin-specific raw data void* rawData = nullptr; - std::vector loaders; // provided loaders + std::vector providers; // provided providers }; } // namespace ac::local diff --git a/local/code/ac/local/PluginInterface.hpp b/local/code/ac/local/PluginInterface.hpp index 2585aa9c..f32fe044 100644 --- a/local/code/ac/local/PluginInterface.hpp +++ b/local/code/ac/local/PluginInterface.hpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT // #pragma once -#include "ModelLoaderPtr.hpp" +#include "ProviderPtr.hpp" #include #include @@ -25,8 +25,8 @@ struct PluginInterface { using InitFunc = void(*)(); InitFunc init; // optional init function - using GetLoadersFunc = std::vector(*)(); - GetLoadersFunc getLoaders; // function to get loaders + using GetProvidersFunc = std::vector(*)(); + GetProvidersFunc getProviders; // function to get providers }; } diff --git a/local/code/ac/local/PluginManager.cpp b/local/code/ac/local/PluginManager.cpp index e673e2e0..535ede37 100644 --- a/local/code/ac/local/PluginManager.cpp +++ b/local/code/ac/local/PluginManager.cpp @@ -3,7 +3,7 @@ // #include "PluginManager.hpp" #include "PluginInterface.hpp" -#include "ModelLoaderRegistry.hpp" +#include "ProviderRegistry.hpp" #include "Logging.hpp" #include "Version.hpp" @@ -46,18 +46,18 @@ inline hplugin load_plugin(const char* filename) { namespace ac::local { -PluginManager::PluginManager(ModelLoaderRegistry& registry) +PluginManager::PluginManager(ProviderRegistry& registry) : m_registry(registry) {} PluginManager::~PluginManager() { for (auto& plugin : m_plugins) { - if (!plugin.loaders.empty()) { - for (auto& loader : plugin.loaders) { - m_registry.removeLoader(*loader); + if (!plugin.providers.empty()) { + for (auto& provider : plugin.providers) { + m_registry.removeProvider(*provider); } } - plugin.loaders.clear(); + plugin.providers.clear(); unload_plugin((hplugin)plugin.nativeHandle); } m_plugins.clear(); @@ -226,15 +226,15 @@ const PluginInfo* PluginManager::tryLoadPlugin(const std::string& path, LoadPlug hplugin = nullptr; // release sentry info.rawData = interface.rawData; - info.loaders = interface.getLoaders(); + info.providers = interface.getProviders(); info.tags.reserve(interface.numTags); for (int i = 0; i < interface.numTags; ++i) { info.tags.push_back(interface.tags[i]); } - for (auto& loader : info.loaders) { - m_registry.addLoader(*loader, &info); + for (auto& provider : info.providers) { + m_registry.addProvider(*provider, &info); } cb.onPluginLoaded(info); diff --git a/local/code/ac/local/PluginManager.hpp b/local/code/ac/local/PluginManager.hpp index 0e65918b..1a4311ec 100644 --- a/local/code/ac/local/PluginManager.hpp +++ b/local/code/ac/local/PluginManager.hpp @@ -9,12 +9,12 @@ #include namespace ac::local { -class ModelLoaderRegistry; +class ProviderRegistry; struct PluginInterface; class AC_LOCAL_EXPORT PluginManager { public: - PluginManager(ModelLoaderRegistry& registry); + PluginManager(ProviderRegistry& registry); ~PluginManager(); PluginManager(const PluginManager&) = delete; @@ -22,7 +22,7 @@ class AC_LOCAL_EXPORT PluginManager { static std::string_view pluginPathToName(std::string_view path); - ModelLoaderRegistry& modelLoaderRegistry() const noexcept { return m_registry; } + ProviderRegistry& providerRegistry() const noexcept { return m_registry; } const std::vector& plugins() const noexcept { return m_plugins; } const std::vector& pluginDirs() const noexcept { return m_pluginDirs; } @@ -38,7 +38,7 @@ class AC_LOCAL_EXPORT PluginManager { private: const PluginInfo* tryLoadPlugin(const std::string& path, LoadPluginCb& cb); - ModelLoaderRegistry& m_registry; + ProviderRegistry& m_registry; std::vector m_pluginDirs; diff --git a/local/code/ac/local/PluginPlibUtil.inl b/local/code/ac/local/PluginPlibUtil.inl index 6b1fc0e3..35df0a2d 100644 --- a/local/code/ac/local/PluginPlibUtil.inl +++ b/local/code/ac/local/PluginPlibUtil.inl @@ -4,7 +4,7 @@ // inline file - no include guard #include -#include +#include #include namespace { @@ -14,10 +14,10 @@ struct PlibHelper { PluginInterface m_pluginInterface; PlibHelper(const PluginInterface& pluginInterface) : m_pluginInterface(pluginInterface) {} - std::vector m_loaders; + std::vector m_loaders; bool m_addedToGlobalRegistry = false; - void fillLoaders() { + void fillProviders() { if (!m_loaders.empty()) { // already filled return; @@ -25,27 +25,27 @@ struct PlibHelper { if (m_pluginInterface.init) { m_pluginInterface.init(); } - m_loaders = m_pluginInterface.getLoaders(); + m_loaders = m_pluginInterface.getProviders(); } - void addLoadersToRegistry(ac::local::ModelLoaderRegistry& registry) { - fillLoaders(); + void addProvidersToRegistry(ac::local::ProviderRegistry& registry) { + fillProviders(); for (auto& loader : m_loaders) { - registry.addLoader(*loader); + registry.addProvider(*loader); } } - void addLoadersToGlobalRegistry() { + void addProvidersToGlobalRegistry() { if (m_addedToGlobalRegistry) { // already added return; } - addLoadersToRegistry(ac::local::Lib::modelLoaderRegistry()); + addProvidersToRegistry(ac::local::Lib::providerRegistry()); m_addedToGlobalRegistry = true; } - const auto& getLoaders() { - fillLoaders(); + const auto& getProviders() { + fillProviders(); return m_loaders; } }; diff --git a/local/code/ac/local/ModelLoader.hpp b/local/code/ac/local/Provider.hpp similarity index 57% rename from local/code/ac/local/ModelLoader.hpp rename to local/code/ac/local/Provider.hpp index 477292e5..1c758e47 100644 --- a/local/code/ac/local/ModelLoader.hpp +++ b/local/code/ac/local/Provider.hpp @@ -10,48 +10,48 @@ namespace ac::local { -/// Base class for model loaders. -/// Model loaders are responsible for loading models based on the provided description and parameters. They are typically -/// facades for an underlying inference library. While model loaders can be used on their own, they are typically used -/// via the `ModelLoaderRegistry` class. +/// Base class for local inference providers. +/// Providers are responsible for creating and managing stateful sessions. They are typically +/// facades for an underlying inference library. While providers can be used on their own, they are typically used +/// via the `ProviderRegistry` class. /// @ingroup cpp-local -class AC_LOCAL_EXPORT ModelLoader { +class AC_LOCAL_EXPORT Provider { public: - virtual ~ModelLoader(); + virtual ~Provider(); struct Info { - /// Human-readable name of the loader. - /// Does not necessarily have to be unique across loaders. + /// Human-readable name of the provider. + /// Does not necessarily have to be unique across providers. std::string name; - /// Optional human readable name of the loader vendor. + /// Optional human readable name of the provider vendor. std::string vendor; - /// Schema for the loader. + /// Schema for the provider. Dict schema; - /// Additional tags that can be used to filter loaders + /// Additional tags that can be used to filter providers std::vector tags; /// Additional metadata that can be used to store more structured information Dict metadata; - /// Loader-specific raw data which can be used to store additional information. + /// Provider-specific raw data which can be used to store additional information. /// Use this only as a last resort. You must make sure you know what's in there. void* rawData = nullptr; }; - /// Info of the loader. + /// Info of the provider. virtual const Info& info() const noexcept = 0; /// Check if the model can be loaded - /// This function is used by `ModelLoaderRegistry` to check if the model should be loaded by this loader. + /// This function is used by `ProviderRegistry` to check if the model should be loaded by this provider. /// Keep it as lightweight as possible. virtual bool canLoadModel(const ModelAssetDesc& desc, const Dict& params) const noexcept = 0; /// Load a model based on the provided description and parameters. /// The progress callback is optional and can be used to report the progress of the loading process. - /// The returned model is owned by the caller and is not bound to the loader in any way. + /// The returned model is owned by the caller and is not bound to the provider in any way. virtual ModelPtr loadModel(ModelAssetDesc desc, Dict params, ProgressCb cb) = 0; }; diff --git a/local/code/ac/local/ModelLoaderPtr.hpp b/local/code/ac/local/ProviderPtr.hpp similarity index 64% rename from local/code/ac/local/ModelLoaderPtr.hpp rename to local/code/ac/local/ProviderPtr.hpp index 27f6eeb0..35fe345a 100644 --- a/local/code/ac/local/ModelLoaderPtr.hpp +++ b/local/code/ac/local/ProviderPtr.hpp @@ -2,9 +2,9 @@ // SPDX-License-Identifier: MIT // #pragma once -#include "ModelLoader.hpp" +#include "Provider.hpp" #include namespace ac::local { -using ModelLoaderPtr = std::unique_ptr; +using ProviderPtr = std::unique_ptr; } // namespace ac::local diff --git a/local/code/ac/local/ProviderRegistry.cpp b/local/code/ac/local/ProviderRegistry.cpp new file mode 100644 index 00000000..5a3324bd --- /dev/null +++ b/local/code/ac/local/ProviderRegistry.cpp @@ -0,0 +1,83 @@ +// Copyright (c) Alpaca Core +// SPDX-License-Identifier: MIT +// +#include "ProviderRegistry.hpp" +#include "Provider.hpp" +#include "CommonProviderScorers.hpp" +#include "Logging.hpp" +#include +#include +#include +#include +#include + +namespace ac::local { + +ProviderRegistry::ProviderRegistry(std::string_view name) + : m_name(name) +{ + if (m_name.empty()) { + char hex[20] = "0x"; + auto r = std::to_chars(hex + 2, hex + sizeof(hex), reinterpret_cast(this), 16); + m_name = std::string_view(hex, r.ptr - hex); + } +} + +inline jalog::BasicStream& operator,(jalog::BasicStream& s, const std::vector& vec) { + s, "["; + for (const auto& v : vec) { + s, v, ", "; + } + s, "]"; + return s; +} + +void ProviderRegistry::addProvider(Provider& provider, PluginInfo* plugin) { + [[maybe_unused]] auto& info = provider.info(); + AC_LOCAL_LOG(Info, "Registry ", m_name, " adding provider ", info.name, + "\n vendor: ", info.vendor, + "\n tags: ", info.tags + ); + + m_providers.push_back({&provider, plugin}); +} + +void ProviderRegistry::removeProvider(Provider& provider) { + astl::erase_first_if(m_providers, [&](const auto& data) { return data.provider == &provider; }); +} + +Provider* ProviderRegistry::findBestProvider( + const ProviderScorer& scorer, const ModelAssetDesc& desc, const Dict& params +) const { + Provider* best = nullptr; + auto bestScore = scorer.denyScore(); + auto acceptScore = scorer.acceptScore(); + + for (const auto& data : m_providers) { + auto score = scorer.score(*data.provider, data.plugin, desc, params); + if (score > bestScore) { + best = data.provider; + bestScore = score; + } + if (score >= acceptScore) { + return best; + } + } + + return best; +} + +ModelPtr ProviderRegistry::loadModel(const ProviderScorer& scorer, ModelAssetDesc desc, Dict params, ProgressCb cb) const { + if (auto provider = findBestProvider(scorer, desc, params)) { + return provider->loadModel(astl::move(desc), astl::move(params), astl::move(cb)); + } + + ac::throw_ex{} << "No provider found for: " << desc.name; + MSVC_WO_10766806(); +} + +ModelPtr ProviderRegistry::loadModel(ModelAssetDesc desc, Dict params, ProgressCb cb) const { + return loadModel(CanLoadScorer{}, astl::move(desc), astl::move(params), astl::move(cb)); +} + +} // namespace ac::local diff --git a/local/code/ac/local/ProviderRegistry.hpp b/local/code/ac/local/ProviderRegistry.hpp new file mode 100644 index 00000000..a1759fad --- /dev/null +++ b/local/code/ac/local/ProviderRegistry.hpp @@ -0,0 +1,55 @@ +// Copyright (c) Alpaca Core +// SPDX-License-Identifier: MIT +// +#pragma once +#include "export.h" +#include "ModelPtr.hpp" +#include "ModelAssetDesc.hpp" +#include "ProgressCb.hpp" +#include +#include +#include + +/// @defgroup cpp-local C++ Local API +/// C++ API for local inference. + +namespace ac::local { +class Provider; +struct PluginInfo; +class ProviderScorer; + +class AC_LOCAL_EXPORT ProviderRegistry { +public: + ProviderRegistry(std::string_view name = {}); + ProviderRegistry(const ProviderRegistry&) = delete; + ProviderRegistry& operator=(const ProviderRegistry&) = delete; + + const std::string& name() const noexcept { return m_name; } + + struct ProviderData { + Provider* provider; // never null + PluginInfo* plugin; // may be null for providers that have been added directly + }; + + const std::vector& providers() const noexcept { return m_providers; } + + void addProvider(Provider& provider, PluginInfo* plugin = nullptr); + void removeProvider(Provider& provider); + + // find the best provider for the given model description and parameters + // returns nullptr if all providers rank equal or lower then the denyScore of the scorer + Provider* findBestProvider(const ProviderScorer& scorer, const ModelAssetDesc& desc, const Dict& params) const; + + // utliity functions to directly load the model + + // load model with the first provider which can load it + ModelPtr loadModel(ModelAssetDesc desc, Dict params, ProgressCb cb = {}) const; + + // load model with a scorer to select the best provider + ModelPtr loadModel(const ProviderScorer& scorer, ModelAssetDesc desc, Dict params, ProgressCb cb = {}) const; +private: + std::string m_name; + std::vector m_providers; +}; + +} // namespace ac::local diff --git a/local/code/ac/local/ModelLoaderScorer.hpp b/local/code/ac/local/ProviderScorer.hpp similarity index 73% rename from local/code/ac/local/ModelLoaderScorer.hpp rename to local/code/ac/local/ProviderScorer.hpp index 78132f29..4cf1e5a5 100644 --- a/local/code/ac/local/ModelLoaderScorer.hpp +++ b/local/code/ac/local/ProviderScorer.hpp @@ -8,17 +8,17 @@ namespace ac::local { -class ModelLoader; +class Provider; struct PluginInfo; struct ModelAssetDesc; -class AC_LOCAL_EXPORT ModelLoaderScorer { +class AC_LOCAL_EXPORT ProviderScorer { public: using score_t = int; - virtual ~ModelLoaderScorer(); + virtual ~ProviderScorer(); - // ignore the loader if the score is less than or equal to this + // ignore the provider if the score is less than or equal to this virtual score_t denyScore() const noexcept { return std::numeric_limits::min(); } @@ -29,8 +29,8 @@ class AC_LOCAL_EXPORT ModelLoaderScorer { } virtual score_t score( - const ModelLoader& loader, - const PluginInfo* loaderPlugin, + const Provider& provider, + const PluginInfo* providerPlugin, const ModelAssetDesc& model, const Dict& params ) const noexcept = 0; diff --git a/local/code/ac/local/VtableExports.cpp b/local/code/ac/local/VtableExports.cpp index 60349592..5153b9e0 100644 --- a/local/code/ac/local/VtableExports.cpp +++ b/local/code/ac/local/VtableExports.cpp @@ -1,16 +1,16 @@ // Copyright (c) Alpaca Core // SPDX-License-Identifier: MIT // -#include "ModelLoaderScorer.hpp" -#include "ModelLoader.hpp" +#include "ProviderScorer.hpp" +#include "Provider.hpp" #include "Model.hpp" #include "Instance.hpp" // export vtables for classes which only have that namespace ac::local { -ModelLoaderScorer::~ModelLoaderScorer() = default; -ModelLoader::~ModelLoader() = default; +ProviderScorer::~ProviderScorer() = default; +Provider::~Provider() = default; Model::~Model() = default; Instance::~Instance() = default; } // namespace ac::local diff --git a/local/test/CMakeLists.txt b/local/test/CMakeLists.txt index 5aaa09fa..030fc4d8 100644 --- a/local/test/CMakeLists.txt +++ b/local/test/CMakeLists.txt @@ -6,4 +6,4 @@ macro(add_local_test test) endmacro() add_local_test(ProgressCb) -add_local_test(ModelLoaderRegistry) +add_local_test(ProviderRegistry) diff --git a/local/test/t-ModelLoaderRegistry.cpp b/local/test/t-ProviderRegistry.cpp similarity index 50% rename from local/test/t-ModelLoaderRegistry.cpp rename to local/test/t-ProviderRegistry.cpp index 2cb70692..03492c4d 100644 --- a/local/test/t-ModelLoaderRegistry.cpp +++ b/local/test/t-ProviderRegistry.cpp @@ -1,12 +1,12 @@ // Copyright (c) Alpaca Core // SPDX-License-Identifier: MIT // -#include -#include -#include +#include +#include +#include #include -using Info = ac::local::ModelLoader::Info; +using Info = ac::local::Provider::Info; Info LlamaA{ .name = "llama a", @@ -20,9 +20,9 @@ Info WhisperX{ .name = "whisper x", }; -struct TestLoader : public ac::local::ModelLoader { +struct TestProvider : public ac::local::Provider { const Info& m_info; - TestLoader(const Info& info) : m_info(info) {} + TestProvider(const Info& info) : m_info(info) {} virtual const Info& info() const noexcept override { return m_info; } virtual bool canLoadModel(const ac::local::ModelAssetDesc& desc, const ac::Dict&) const noexcept override { return m_info.name.starts_with(desc.type); @@ -32,27 +32,27 @@ struct TestLoader : public ac::local::ModelLoader { } }; -TEST_CASE("ModelLoaderRegistry") { - ac::local::ModelLoaderRegistry registry; +TEST_CASE("ProviderRegistry") { + ac::local::ProviderRegistry registry; ac::local::ModelAssetDesc - llama {.type = "llama", .name = "llama-7b"}, - whisper {.type = "whisper", .name = "whisper-tiny"}; + llama{ .type = "llama", .name = "llama-7b" }, + whisper{ .type = "whisper", .name = "whisper-tiny" }; ac::local::CanLoadScorer s; - CHECK_FALSE(registry.findBestLoader(s, llama, {})); + CHECK_FALSE(registry.findBestProvider(s, llama, {})); - TestLoader llamaA(LlamaA), llamaB(LlamaB), whisperX(WhisperX); + TestProvider llamaA(LlamaA), llamaB(LlamaB), whisperX(WhisperX); - registry.addLoader(llamaA); - registry.addLoader(llamaB); - registry.addLoader(whisperX); + registry.addProvider(llamaA); + registry.addProvider(llamaB); + registry.addProvider(whisperX); - auto la = registry.findBestLoader(s, llama, {}); + auto la = registry.findBestProvider(s, llama, {}); REQUIRE(la); CHECK(&la->info() == &LlamaA); - auto wh = registry.findBestLoader(s, whisper, {}); + auto wh = registry.findBestProvider(s, whisper, {}); REQUIRE(wh); CHECK(&wh->info() == &WhisperX); } diff --git a/schema/code/ac/schema/GenerateLoaderSchemaDict.hpp b/schema/code/ac/schema/GenerateProviderSchemaDict.hpp similarity index 97% rename from schema/code/ac/schema/GenerateLoaderSchemaDict.hpp rename to schema/code/ac/schema/GenerateProviderSchemaDict.hpp index 44bc2d90..c4da2a65 100644 --- a/schema/code/ac/schema/GenerateLoaderSchemaDict.hpp +++ b/schema/code/ac/schema/GenerateProviderSchemaDict.hpp @@ -8,7 +8,7 @@ namespace ac::local::schema { template -Dict generateLoaderSchema() { +Dict generateProviderSchema() { Dict dict; dict["id"] = Schema::id; dict["description"] = Schema::description;