Skip to content

Commit

Permalink
refactor(local): local dummy session tests, ref #222
Browse files Browse the repository at this point in the history
  • Loading branch information
iboB committed Jan 13, 2025
1 parent 6be669f commit 49e2b61
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 64 deletions.
1 change: 1 addition & 0 deletions api/code/ac/Session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ void Session::resetHandler(SessionHandlerPtr handler, std::unique_ptr<SessionExe
assert(m_handler->m_session == nullptr);
m_handler->m_session = this;
m_handler->m_executor = std::move(executor);
// not necessarily opened
}
}

Expand Down
3 changes: 3 additions & 0 deletions api/code/ac/SessionHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

namespace ac {

SessionHandler::SessionHandler() = default;
SessionHandler::~SessionHandler() = default;

void SessionHandler::postSessionStrandTask(std::function<void()> task) {
m_executor->post(astl::move(task));
}
Expand Down
13 changes: 9 additions & 4 deletions dummy-plugin/code/ac/dummy/LocalDummy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class TFsm {
}
};

class DummySession final : public TFsm<DummySession>, public SessionHandler {
class DummySessionHandler final : public TFsm<DummySessionHandler>, public SessionHandler {
struct StateInitial final : public State {
using State::State;

Expand Down Expand Up @@ -125,7 +125,7 @@ class DummySession final : public TFsm<DummySession>, public SessionHandler {
if (!ret) {
throw_ex{} << "dummy: unknown op: " << f.op;
}
fsm.writeFrame(Frame{"ret", *ret});
fsm.writeFrame(Frame{f.op, *ret});
}

schema::DummyInterface::OpRun::Return on(schema::DummyInterface::OpRun, schema::DummyInterface::OpRun::Params params) {
Expand Down Expand Up @@ -159,7 +159,12 @@ class DummySession final : public TFsm<DummySession>, public SessionHandler {
void pumpInFrames() {
while (sessionHasInFrames()) {
auto f = getSessionInFrame();
on(*f);
try {
on(*f);
}
catch (std::exception& e) {
writeFrame(Frame{"error", e.what()});
}
}
}
void writeFrame(Frame&& t) {
Expand Down Expand Up @@ -310,7 +315,7 @@ class DummyProvider final : public Provider {
}

virtual SessionHandlerPtr createSessionHandler(std::string_view) {
return {};
return std::make_shared<DummySessionHandler>();
}
};

Expand Down
150 changes: 95 additions & 55 deletions dummy-plugin/test/t-dummy.inl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <ac/local/Model.hpp>
#include <ac/local/Instance.hpp>

#include <ac/local/SyncSession.hpp>

#include <doctest/doctest.h>

#include <astl/move.hpp>
Expand All @@ -21,80 +23,118 @@ const ac::local::ModelAssetDesc Model_Desc = {
}
};

std::shared_ptr<ac::SessionHandler> createDummyHandler() {
DummyRegistry d;
std::unique_ptr<ac::local::SyncSession> createTestSession(DummyRegistry& d) {
REQUIRE(d.providers().size() == 1);
return d.providers().front().provider->createSessionHandler({});
auto s = std::make_unique<ac::local::SyncSession>(d.providers().front().provider->createSessionHandler({}));
REQUIRE(s);
REQUIRE(s->valid());
return s;
}

void checkError(ac::local::SyncSession& s, const std::string_view msg) {
auto frame = s.get();
REQUIRE(frame);
CHECK(frame->op == "error");
CHECK(frame->data.get<std::string_view>() == msg);
}

TEST_CASE("bad model") {
DummyRegistry f;
CHECK_THROWS_WITH(
f.loadModel({
.type = "dummy",
.assets = {
{.path = "nope"}
}
}, {}),
"Failed to open file: nope"
);
DummyRegistry d;
auto s = createTestSession(d);
CHECK_FALSE(s->get());
s->put({"nope", {}});
checkError(*s, "dummy: expected 'load' op, got: nope");

s->put({"load", {{"file_path", "nope"}}});
checkError(*s, "Failed to open file: nope");
}

//TEST_CASE("bad model") {
// DummyRegistry f;
// CHECK_THROWS_WITH(
// f.loadModel({
// .type = "dummy",
// .assets = {
// {.path = "nope"}
// }
// }, {}),
// "Failed to open file: nope"
// );
//}

TEST_CASE("bad instance") {
DummyRegistry f;
auto model = f.loadModel(Model_Desc, {});
REQUIRE(model);
CHECK_THROWS_WITH(model->createInstance("nope", {}), "dummy: unknown instance type: nope");
CHECK_THROWS_WITH(model->createInstance("general", {{"cutoff", 40}}),
"Cutoff 40 greater than model size 3");
DummyRegistry d;
auto s = createTestSession(d);

s->put({ "load", {{"file_path", AC_DUMMY_MODEL_SMALL}} });
CHECK_FALSE(s->get());

s->put({ "nope", {} });
checkError(*s, "dummy: expected 'create' op, got: nope");

s->put({ "create", {{"cutoff", 40}} });
checkError(*s, "Cutoff 40 greater than model size 3");
}

TEST_CASE("general") {
DummyRegistry f;
auto model = f.loadModel(Model_Desc, {});
REQUIRE(model);
DummyRegistry d;
auto s = createTestSession(d);

s->put({"load", {{"file_path", AC_DUMMY_MODEL_SMALL}}});
CHECK_FALSE(s->get());

s->put({"create", {}});
CHECK_FALSE(s->get());

s->put({"nope", {}});
checkError(*s, "dummy: unknown op: nope");

auto i = model->createInstance("general", {});
REQUIRE(i);
s->put({ "run", {{"foo", "nope"}}});
checkError(*s, "Required field input is not set");

CHECK_THROWS_WITH(i->runOp("nope", {}), "dummy: unknown op: nope");
s->put({"run", {{"input", {"a", "b"}}}});
auto f = s->get();
REQUIRE(f);
CHECK(f->op == "run");
CHECK(f->data.at("result").get<std::string>() == "a soco b bate");

CHECK_THROWS_WITH(i->runOp("run", {{"foo", "nope"}}), "Required field input is not set");
s->put({"run", {{"input", {"a", "b"}}, {"splice", false}}});
f = s->get();
REQUIRE(f);
CHECK(f->op == "run");
CHECK(f->data.at("result").get<std::string>() == "a b soco bate vira");

auto res = i->runOp("run", {{"input", {"a", "b"}}});
CHECK(res.at("result").get<std::string>() == "a soco b bate");
s->put({"run", {{"input", {"a", "b"}}, {"throw_on", 3}}});
checkError(*s, "Throw on token 3");

res = i->runOp("run", {{"input", {"a", "b"}}, {"splice", false}});
CHECK(res.at("result").get<std::string>() == "a b soco bate vira");
auto s2 = createTestSession(d);

CHECK_THROWS_WITH(i->runOp("run", {{"input", {"a", "b"}}, {"throw_on", 3}}), "Throw on token 3");
s2->put({"load", {{"file_path", AC_DUMMY_MODEL_SMALL}}});
CHECK_FALSE(s2->get());

auto ci = model->createInstance("general", {{"cutoff", 2}});
REQUIRE(ci);
s2->put({"create", {{"cutoff", 2}}});
CHECK_FALSE(s2->get());

res = ci->runOp("run", {{"input", {"a", "b", "c"}}});
CHECK(res.at("result").get<std::string>() == "a soco b bate c soco");
s2->put({"run", {{"input", {"a", "b", "c"}}}});
f = s2->get();
REQUIRE(f);
CHECK(f->op == "run");
CHECK(f->data.at("result").get<std::string>() == "a soco b bate c soco");
}

TEST_CASE("synthetic") {
DummyRegistry f;

std::string tag;
float progress;
auto model = f.loadModel({
.type = "dummy",
.assets = {}
}, {}, [&](const std::string_view t, float p) {
tag = std::string(t);
progress = p;
return true;
});
REQUIRE(model);
CHECK(tag == "synthetic");
CHECK(progress == 0.5f);

auto instance = model->createInstance("general", {});

auto res = instance->runOp("run", {{"input", {"a", "b"}}});
CHECK(res.at("result").get<std::string>() == "a one b two");
DummyRegistry d;
auto s = createTestSession(d);

s->put({"load", {}});
CHECK_FALSE(s->get());

s->put({"create", {}});
CHECK_FALSE(s->get());

s->put({"run", {{"input", {"a", "b"}}}});
auto f = s->get();
REQUIRE(f);
CHECK(f->op == "run");
CHECK(f->data.at("result").get<std::string>() == "a one b two");;
}
7 changes: 3 additions & 4 deletions local/code/ac/local/SyncSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ struct SynSessionExecutor final : public SessionExecutor {
};
}

SyncSession::SyncSession(SessionHandlerPtr handler)
: m_handler(std::move(handler))
{
resetHandler(m_handler, std::make_unique<SynSessionExecutor>());
SyncSession::SyncSession(SessionHandlerPtr handler) {
resetHandler(handler, std::make_unique<SynSessionExecutor>());
m_handler->shOpened();
}
SyncSession::~SyncSession() {
close();
Expand Down
1 change: 0 additions & 1 deletion local/code/ac/local/SyncSession.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class AC_LOCAL_EXPORT SyncSession final : public Session {

void close() override;

SessionHandlerPtr m_handler;
std::optional<Frame> m_inFrame;
std::optional<Frame> m_outFrame;

Expand Down

0 comments on commit 49e2b61

Please sign in to comment.