diff --git a/src/user/srv.init/main.cpp b/src/user/srv.init/main.cpp index b8bcd95..57f5961 100644 --- a/src/user/srv.init/main.cpp +++ b/src/user/srv.init/main.cpp @@ -1,6 +1,13 @@ +#include #include "j6/syscalls.h" +#include "init_args.h" + #include "modules.h" +using kernel::init::module; +using kernel::init::module_type; +using kernel::init::module_program; + extern "C" { int main(int, const char **); } @@ -14,7 +21,15 @@ int main(int argc, const char **argv) { j6_system_log("srv.init starting"); - modules::load_all(_arg_modules_phys); + + modules mods = modules::load_modules(_arg_modules_phys, handle_system, handle_self); + + char message[100]; + for (auto &mod : mods.of_type(module_type::program)) { + auto &prog = static_cast(mod); + sprintf(message, " program module '%s' at %lx", prog.filename, prog.base_address); + j6_system_log(message); + } return 0; } diff --git a/src/user/srv.init/modules.cpp b/src/user/srv.init/modules.cpp index 17fe372..eee0382 100644 --- a/src/user/srv.init/modules.cpp +++ b/src/user/srv.init/modules.cpp @@ -3,68 +3,80 @@ #include "j6/errors.h" #include "j6/syscalls.h" -#include "init_args.h" -#include "pointer_manipulation.h" #include "modules.h" -using namespace kernel::init; - -extern j6_handle_t handle_self; -extern j6_handle_t handle_system; - -namespace modules { +using module = kernel::init::module; +using modules_page = kernel::init::modules_page; static const modules_page * -load_page(uintptr_t address) +get_page(const module *mod) +{ + return reinterpret_cast( + reinterpret_cast(mod) & ~0xfffull); +} + +const module * +module_iterator::operator++() +{ + do { + m_mod = offset_ptr(m_mod, m_mod->mod_length); + + if (m_mod->mod_type == type::none) { + // We've reached the end of a page, see if there's another + const modules_page *page = get_page(m_mod); + if (!page->next) { + m_mod = nullptr; + break; + } + + m_mod = page->modules; + } + } + while (m_type != type::none && m_type != m_mod->mod_type); + + return m_mod; +} + +const module * +module_iterator::operator++(int) +{ + const module *tmp = m_mod; + operator++(); + return tmp; +} + +const modules_page * +load_page(uintptr_t address, j6_handle_t system, j6_handle_t self) { j6_handle_t mods_vma = j6_handle_invalid; - j6_status_t s = j6_system_map_phys(handle_system, &mods_vma, address, 0x1000, 0); + j6_status_t s = j6_system_map_phys(system, &mods_vma, address, 0x1000, 0); if (s != j6_status_ok) exit(s); - s = j6_vma_map(mods_vma, handle_self, address); + s = j6_vma_map(mods_vma, self, address); if (s != j6_status_ok) exit(s); - return reinterpret_cast(address); + return reinterpret_cast(address); } -void -load_all(uintptr_t address) +modules +modules::load_modules(uintptr_t address, j6_handle_t system, j6_handle_t self) { - module_framebuffer const *framebuffer = nullptr; - + const module *first = nullptr; while (address) { - const modules_page *page = load_page(address); + const modules_page *page = load_page(address, system, self); char message[100]; - sprintf(message, "srv.init loading %d modules from page at 0x%lx", page->count, address); + sprintf(message, "srv.init found %d modules from page at 0x%lx", page->count, address); j6_system_log(message); - module *mod = page->modules; - size_t count = page->count; - while (count--) { - switch (mod->mod_type) { - case module_type::framebuffer: - framebuffer = reinterpret_cast(mod); - break; - - case module_type::program: - if (mod->mod_flags == module_flags::no_load) - j6_system_log("Loaded program module"); - else - j6_system_log("Non-loaded program module"); - break; - - default: - j6_system_log("Unknown module"); - } - mod = offset_ptr(mod, mod->mod_length); - } - + if (!first) + first = page->modules; address = page->next; } + + return modules {first}; } -} // namespace modules diff --git a/src/user/srv.init/modules.h b/src/user/srv.init/modules.h index 9f0e1fb..e3bfa0b 100644 --- a/src/user/srv.init/modules.h +++ b/src/user/srv.init/modules.h @@ -2,10 +2,61 @@ /// \file modules.h /// Routines for loading initial argument modules -namespace modules { +#include "j6/types.h" +#include "init_args.h" +#include "pointer_manipulation.h" -/// Load all modules -/// \arg address The physical address of the first page of modules -void load_all(uintptr_t address); -} // namespace modules +class module_iterator +{ +public: + using type = kernel::init::module_type; + using module = kernel::init::module; + + module_iterator(const module *m, type t = type::none) : + m_mod {m}, m_type {t} {} + + const module * operator++(); + const module * operator++(int); + + bool operator==(const module* m) const { return m == m_mod; } + bool operator!=(const module* m) const { return m != m_mod; } + bool operator==(const module_iterator &i) const { return i.m_mod == m_mod; } + bool operator!=(const module_iterator &i) const { return i.m_mod != m_mod; } + + const module & operator*() const { return *m_mod; } + operator const module & () const { return *m_mod; } + const module * operator->() const { return m_mod; } + + // Allow iterators to be used in for(:) statments + module_iterator & begin() { return *this; } + module_iterator end() const { return nullptr; } + +private: + module const * m_mod; + type m_type; +}; + +class modules +{ +public: + using type = kernel::init::module_type; + using iterator = module_iterator; + + static modules load_modules( + uintptr_t address, + j6_handle_t system, + j6_handle_t self); + + iterator of_type(type t) const { return iterator {m_root, t}; } + + iterator begin() const { return iterator {m_root}; } + iterator end() const { return nullptr; } + +private: + using module = kernel::init::module; + + modules(const module* root) : m_root {root} {} + + const module *m_root; +};