From 4e73e933db6cdd030597e0f0b52277ef5431287d Mon Sep 17 00:00:00 2001 From: "Justin C. Miller" Date: Tue, 20 Feb 2024 20:51:14 -0800 Subject: [PATCH] [libj6] Update how init args are passed and used Now the init args are a linked list - this also means ld.so can use its own plus those of the program (eg, SLP and VFS handles). __init_libj6 now adds the head of the list to its global init_args structure, and the j6_find_init_handle function can be used to find a handle in those args for a given proto. This fixes situations like the logger using the wrong mailbox for the service locator and never finding the uart driver. --- src/libraries/j6/include/j6/init.h | 14 ++- src/libraries/j6/init.cpp | 29 +++++- src/libraries/libc/arch/amd64/crt/crt0.s | 2 + src/user/drv.uart/main.cpp | 5 +- src/user/drv.uefi_fb/main.cpp | 4 +- src/user/ld.so/main.cpp | 4 +- src/user/srv.init/loader.cpp | 125 +++++++++++++---------- src/user/srv.init/main.cpp | 2 +- src/user/srv.init/start.s | 7 ++ src/user/srv.logger/main.cpp | 5 +- 10 files changed, 125 insertions(+), 72 deletions(-) diff --git a/src/libraries/j6/include/j6/init.h b/src/libraries/j6/include/j6/init.h index 0839372..93bce28 100644 --- a/src/libraries/j6/include/j6/init.h +++ b/src/libraries/j6/include/j6/init.h @@ -33,6 +33,13 @@ struct j6_arg_header { uint32_t size; uint16_t type; + uint16_t reserved; + j6_arg_header *next; +}; + +struct j6_arg_none +{ + add_header(none); }; struct j6_arg_loader @@ -67,14 +74,17 @@ struct j6_arg_handles struct j6_init_args { - uint64_t args[2]; + uint64_t argv[2]; + j6_arg_header *args; }; - /// Find the first handle of the given type held by this process j6_handle_t API j6_find_first_handle(j6_object_type obj_type); +/// Find the first handle tagged with the given proto in the process init args +j6_handle_t API j6_find_init_handle(uint64_t proto); + /// Get the init args const j6_init_args * j6_get_init_args(); diff --git a/src/libraries/j6/init.cpp b/src/libraries/j6/init.cpp index 1befef1..e28aaf7 100644 --- a/src/libraries/j6/init.cpp +++ b/src/libraries/j6/init.cpp @@ -12,7 +12,7 @@ namespace { constexpr size_t static_arr_count = 32; j6_handle_descriptor handle_array[static_arr_count]; - j6_init_args init_args; + j6_init_args init_args = { 0, 0, 0 }; } // namespace j6_handle_t @@ -36,6 +36,26 @@ j6_find_first_handle(j6_object_type obj_type) return j6_handle_invalid; } +j6_handle_t +j6_find_init_handle(uint64_t proto) +{ + j6_arg_header *arg = init_args.args; + while (arg) { + if (arg->type == j6_arg_type_handles) { + j6_arg_handles *harg = reinterpret_cast(arg); + for (unsigned i = 0; i < harg->nhandles; ++i) { + j6_arg_handle_entry &ent = harg->handles[i]; + if (ent.proto == proto) + return ent.handle; + } + } + arg = arg->next; + } + + return j6_handle_invalid; +} + + const j6_init_args * API j6_get_init_args() { @@ -43,10 +63,11 @@ j6_get_init_args() } extern "C" void API -__init_libj6(uint64_t arg0, uint64_t arg1) +__init_libj6(uint64_t argv0, uint64_t argv1, j6_arg_header *args) { - init_args.args[0] = arg0; - init_args.args[1] = arg1; + init_args.argv[0] = argv0; + init_args.argv[1] = argv1; + init_args.args = args; } diff --git a/src/libraries/libc/arch/amd64/crt/crt0.s b/src/libraries/libc/arch/amd64/crt/crt0.s index e5a5414..299e8cf 100644 --- a/src/libraries/libc/arch/amd64/crt/crt0.s +++ b/src/libraries/libc/arch/amd64/crt/crt0.s @@ -18,6 +18,8 @@ global _libc_crt0_start:function (_libc_crt0_start.end - _libc_crt0_start) _start: _libc_crt0_start: + mov rdx, [rsp] ; grab args pointer + push 0 ; Add null frame push 0 mov rbp, rsp diff --git a/src/user/drv.uart/main.cpp b/src/user/drv.uart/main.cpp index f8677f3..b9ffa40 100644 --- a/src/user/drv.uart/main.cpp +++ b/src/user/drv.uart/main.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -42,7 +41,7 @@ main(int argc, const char **argv) j6_handle_t event = j6_handle_invalid; j6_status_t result = j6_status_ok; - g_handle_sys = j6_find_first_handle(j6_object_type_system); + g_handle_sys = j6_find_init_handle(0); if (g_handle_sys == j6_handle_invalid) return 1; @@ -68,7 +67,7 @@ main(int argc, const char **argv) static constexpr size_t buffer_size = 512; char buffer[buffer_size]; - j6_handle_t slp = j6_find_first_handle(j6_object_type_mailbox); + j6_handle_t slp = j6_find_init_handle(j6::proto::sl::id); if (slp == j6_handle_invalid) return 1; diff --git a/src/user/drv.uefi_fb/main.cpp b/src/user/drv.uefi_fb/main.cpp index 27c9b30..cbd2cda 100644 --- a/src/user/drv.uefi_fb/main.cpp +++ b/src/user/drv.uefi_fb/main.cpp @@ -29,7 +29,7 @@ main(int argc, const char **argv, const char **env) using bootproto::devices::fb_layout; const j6_init_args *init = j6_get_init_args(); - const uefi_fb *fb = reinterpret_cast(init->args[0]); + const uefi_fb *fb = reinterpret_cast(init->argv[0]); if (!fb || !fb->framebuffer) { j6::syslog("fb driver didn't find a framebuffer, exiting"); @@ -45,7 +45,7 @@ main(int argc, const char **argv, const char **env) j6_vm_flag_write_combine | j6_vm_flag_mmio; - j6_handle_t sys = j6_find_first_handle(j6_object_type_system); + j6_handle_t sys = j6_find_init_handle(0); if (sys == j6_handle_invalid) return 1; diff --git a/src/user/ld.so/main.cpp b/src/user/ld.so/main.cpp index b1893e5..63a839d 100644 --- a/src/user/ld.so/main.cpp +++ b/src/user/ld.so/main.cpp @@ -18,7 +18,7 @@ ldso_init(j6_arg_header *stack_args, uintptr_t *got) j6_arg_handles *arg_handles = nullptr; j6_arg_header *arg = stack_args; - while (arg && arg->type != j6_arg_type_none) { + while (arg) { switch (arg->type) { case j6_arg_type_loader: @@ -33,7 +33,7 @@ ldso_init(j6_arg_header *stack_args, uintptr_t *got) break; } - arg = util::offset_pointer(arg, arg->size); + arg = arg->next; } if (!arg_loader) { diff --git a/src/user/srv.init/loader.cpp b/src/user/srv.init/loader.cpp index be56e8f..ea83225 100644 --- a/src/user/srv.init/loader.cpp +++ b/src/user/srv.init/loader.cpp @@ -25,6 +25,47 @@ static util::xoroshiro256pp rng {0x123456}; inline uintptr_t align_up(uintptr_t a) { return ((a-1) & ~(MiB-1)) + MiB; } +class stack_pusher +{ +public: + stack_pusher(uint8_t *local_top, uintptr_t child_top) : + m_local_top {local_top}, m_child_top {child_top}, m_used {0}, m_last_arg {0} { + memset(local_top - 4096, 0, 4096); // Zero top page + } + + template + T * push(size_t extra = 0) { + m_used += sizeof(T) + extra; + m_used = (m_used + (A-1ull)) & ~(A-1ull); + return reinterpret_cast(local_pointer()); + } + + template + T * push_arg(size_t extra = 0) { + T * arg = push(extra); + arg->header.size = sizeof(T) + extra; + arg->header.type = T::type_id; + arg->header.next = reinterpret_cast(m_last_arg); + m_last_arg = child_pointer(); + return arg; + } + + void push_current_pointer() { + uintptr_t addr = child_pointer(); + uintptr_t *ptr = push(); + *ptr = addr; + } + + uint8_t * local_pointer() { return m_local_top - m_used; } + uintptr_t child_pointer() { return m_child_top - m_used; } + +private: + uint8_t *m_local_top; + uintptr_t m_child_top; + size_t m_used; + uintptr_t m_last_arg; +}; + j6_handle_t map_phys(j6_handle_t sys, uintptr_t phys, size_t len, j6_vm_flags flags) { @@ -40,31 +81,6 @@ map_phys(j6_handle_t sys, uintptr_t phys, size_t len, j6_vm_flags flags) return vma; } -void -stack_push_sentinel(uint8_t *&stack) -{ - static constexpr size_t size = sizeof(j6_arg_header); - - stack -= size; - memset(stack, 0, size); - j6_arg_header *header = reinterpret_cast(stack); - header->type = j6_arg_type_none; - header->size = size; -} - -template T * -stack_push(uint8_t *&stack, size_t extra) -{ - size_t len = sizeof(T) + extra; - size_t size = (len + 7) & ~7ull; - stack -= size; - memset(stack, 0, sizeof(T)); - T * arg = reinterpret_cast(stack); - arg->header.type = T::type_id; - arg->header.size = len; - return arg; -} - uintptr_t load_program_into(j6_handle_t proc, elf::file &file, uintptr_t image_base, const char *path) { @@ -205,50 +221,51 @@ load_program( return false; } - uint8_t *stack_orig = reinterpret_cast(stack_addr + stack_size); - uint8_t *stack = stack_orig; - memset(stack - 4096, 0, 4096); // Zero top page - stack_push_sentinel(stack); + stack_pusher stack { + reinterpret_cast(stack_addr + stack_size), + stack_top, + }; + + // Push program's arg sentinel + stack.push_arg(); + + j6_arg_handles *handles_arg = stack.push_arg(3 * sizeof(j6_arg_handle_entry)); + handles_arg->nhandles = 3; + handles_arg->handles[0].handle = sys; + handles_arg->handles[0].proto = 0; + handles_arg->handles[1].handle = slp; + handles_arg->handles[1].proto = j6::proto::sl::id; + handles_arg->handles[2].handle = vfs; + handles_arg->handles[2].proto = j6::proto::vfs::id; if (arg) { - size_t data_size = arg->bytes - sizeof(module); - j6_arg_driver *driver_arg = stack_push(stack, data_size); + size_t data_size = arg->bytes - sizeof(*arg); + j6_arg_driver *driver_arg = stack.push_arg(data_size); driver_arg->device = arg->type_id; const uint8_t *arg_data = arg->data(); memcpy(driver_arg->data, arg_data, data_size); } + // Add an aligned pointer to the program's args list + stack.push_current_pointer(); + uintptr_t entrypoint = program_elf.entrypoint() + program_image_base; if (dyn) { - stack_push_sentinel(stack); - j6_arg_loader *loader_arg = stack_push(stack, 0); - const elf::file_header *h = program_elf.header(); + // Push loaders's arg sentinel + stack.push_arg(); + + j6_arg_loader *loader_arg = stack.push_arg(); loader_arg->image_base = program_image_base; + loader_arg->entrypoint = program_elf.entrypoint(); // ld.so will offset the entrypoint, don't do it here. const elf::section_header *got_section = program_elf.get_section_by_name(".got.plt"); if (got_section) loader_arg->got = reinterpret_cast(program_image_base + got_section->addr); - // The dynamic linker will offset the entrypoint, don't do it here. - loader_arg->entrypoint = program_elf.entrypoint(); - - j6_arg_handles *handles_arg = stack_push(stack, 2 * sizeof(j6_arg_handle_entry)); - handles_arg->nhandles = 2; - handles_arg->handles[0].handle = slp; - handles_arg->handles[0].proto = j6::proto::sl::id; - handles_arg->handles[1].handle = vfs; - handles_arg->handles[1].proto = j6::proto::vfs::id; - - // Align the stack to be one word short of 16-byte aligned, so - // that the arg address will be aligned when pushed - while ((reinterpret_cast(stack) & 0xf) != 0x8) --stack; - - // Push the args list address itself - stack -= sizeof(uintptr_t); - uintptr_t *args_addr = reinterpret_cast(stack); - *args_addr = stack_top - (stack_orig - reinterpret_cast(handles_arg)); + // Add an aligned pointer to the loaders's args list + stack.push_current_pointer(); uintptr_t ldso_image_base = (eop & ~(MiB-1)) + MiB; @@ -290,9 +307,7 @@ load_program( } j6_handle_t thread = j6_handle_invalid; - uintptr_t target_stack = stack_top - (stack_orig - stack); - target_stack &= ~0xfull; // Align to 16-byte stack - res = j6_thread_create(&thread, proc, target_stack, entrypoint, program_image_base, 0); + res = j6_thread_create(&thread, proc, stack.child_pointer(), entrypoint, program_image_base, 0); if (res != j6_status_ok) { j6::syslog(" ** error loading program '%s': creating thread: %lx", path, res); return false; diff --git a/src/user/srv.init/main.cpp b/src/user/srv.init/main.cpp index 03d5e18..ba55a49 100644 --- a/src/user/srv.init/main.cpp +++ b/src/user/srv.init/main.cpp @@ -75,7 +75,7 @@ main(int argc, const char **argv, const char **env) return s; const j6_init_args *initp = j6_get_init_args(); - uintptr_t modules_addr = initp->args[0]; + uintptr_t modules_addr = initp->argv[0]; std::vector mods; load_modules(modules_addr, sys, 0, mods); diff --git a/src/user/srv.init/start.s b/src/user/srv.init/start.s index 4ab0c51..d9f20fe 100644 --- a/src/user/srv.init/start.s +++ b/src/user/srv.init/start.s @@ -18,5 +18,12 @@ _start: ; No parent process exists to have created init's stack, so we create a ; stack in BSS and assign that to be init's first stack mov rsp, init_stack_top + + ; Push a fake j6_arg_none + push 0x00 ; pad for 16-byte alignment + push 0x00 ; no next arg + push 0x10 ; size 16 bytes, type 0 (none) + push rsp + jmp _libc_crt0_start .end: diff --git a/src/user/srv.logger/main.cpp b/src/user/srv.logger/main.cpp index 0f1f44c..d02af13 100644 --- a/src/user/srv.logger/main.cpp +++ b/src/user/srv.logger/main.cpp @@ -95,12 +95,11 @@ main(int argc, const char **argv) { j6_log("logging server starting"); - - g_handle_sys = j6_find_first_handle(j6_object_type_system); + g_handle_sys = j6_find_init_handle(0); if (g_handle_sys == j6_handle_invalid) return 1; - j6_handle_t slp = j6_find_first_handle(j6_object_type_mailbox); + j6_handle_t slp = j6_find_init_handle(j6::proto::sl::id); if (g_handle_sys == j6_handle_invalid) return 2;