diff --git a/src/kernel/idt.h b/src/kernel/idt.h index 395ca4c..05a9666 100644 --- a/src/kernel/idt.h +++ b/src/kernel/idt.h @@ -2,6 +2,7 @@ /// \file idt.h /// Definitions relating to a CPU's IDT table #include +#include "assert.h" class IDT { @@ -23,17 +24,30 @@ public: /// stacks can be created. void add_ist_entries(); - /// Get the IST entry used by an entry. + /// Get the IST entry used by an entry, clearing it in the process. /// \arg i Which IDT entry to look in /// \returns The IST index used by entry i, or 0 for none - inline uint8_t get_ist(uint8_t i) const { - return m_entries[i].ist; + inline uint8_t get_ist(uint8_t i) { + return __atomic_exchange_n(&m_entries[i].ist, 0, __ATOMIC_SEQ_CST); } - /// Set the IST entry used by an entry. + /// Restore the IST entry used by an entry when done using it. + /// \arg i Which IDT entry to restore + /// \arg ist The IST index for entry i, or 0 for none + inline void return_ist(uint8_t i, uint8_t ist) { + if (!ist) return; + uint8_t expected = 0; + bool result = __atomic_compare_exchange_n( + &m_entries[i].ist, &expected, ist, + false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + kassert(result, "Tried to overwrite a non-zero IST value in IDT"); + } + + /// Set the IST entry used by an entry. This should not be called + /// by interrupt handlers. /// \arg i Which IDT entry to set /// \arg ist The IST index for entry i, or 0 for none - void set_ist(uint8_t i, uint8_t ist) { m_entries[i].ist = ist; } + inline void set_ist(uint8_t i, uint8_t ist) { m_entries[i].ist = ist; } /// Get the IST entries that are used by this table, as a bitmap static uint8_t used_ist_entries(); diff --git a/src/kernel/interrupts.cpp b/src/kernel/interrupts.cpp index 91b3839..fe1750c 100644 --- a/src/kernel/interrupts.cpp +++ b/src/kernel/interrupts.cpp @@ -75,8 +75,6 @@ isr_handler(cpu_state *regs) // this stack IDT &idt = IDT::current(); uint8_t old_ist = idt.get_ist(vector); - if (old_ist) - idt.set_ist(vector, 0); char message[200]; @@ -165,7 +163,7 @@ isr_handler(cpu_state *regs) // Return the IST for this vector to what it was if (old_ist) - idt.set_ist(vector, old_ist); + idt.return_ist(vector, old_ist); *reinterpret_cast(apic_eoi_addr) = 0; }