|
|
@@ -0,0 +1,150 @@
|
|
|
+package com.malk.pro.tenant;
|
|
|
+
|
|
|
+import org.junit.After;
|
|
|
+import org.junit.Test;
|
|
|
+
|
|
|
+import java.util.Collections;
|
|
|
+import java.util.concurrent.CountDownLatch;
|
|
|
+import java.util.concurrent.TimeUnit;
|
|
|
+import java.util.concurrent.atomic.AtomicReference;
|
|
|
+
|
|
|
+import static org.junit.Assert.assertEquals;
|
|
|
+import static org.junit.Assert.assertNotEquals;
|
|
|
+import static org.junit.Assert.assertNull;
|
|
|
+import static org.junit.Assert.assertTrue;
|
|
|
+import static org.junit.Assert.fail;
|
|
|
+
|
|
|
+/**
|
|
|
+ * 单元测试:{@link TenantContext}
|
|
|
+ *
|
|
|
+ * <p>覆盖 set / current / clear / propagate 全部分支,以及 ThreadLocal 跨线程隔离。
|
|
|
+ * 纯静态工具类测试,无需 Spring 上下文。</p>
|
|
|
+ *
|
|
|
+ * <p>来源:add-mjava-pro tasks §7.1。</p>
|
|
|
+ */
|
|
|
+public class TenantContextTest {
|
|
|
+
|
|
|
+ @After
|
|
|
+ public void cleanup() {
|
|
|
+ // 防止单个 test 残留污染下个 test
|
|
|
+ TenantContext.clear();
|
|
|
+ }
|
|
|
+
|
|
|
+ private TenantProfile profile(String tenantId) {
|
|
|
+ return TenantProfile.builder()
|
|
|
+ .tenantId(tenantId)
|
|
|
+ .enabled(true)
|
|
|
+ .vendorCredentials(Collections.emptyMap())
|
|
|
+ .build();
|
|
|
+ }
|
|
|
+
|
|
|
+ // ---------- 基础读写 ----------
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void current_null_when_unset() {
|
|
|
+ assertNull(TenantContext.current());
|
|
|
+ assertNull(TenantContext.currentTenantId());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void set_then_current_returns_same_profile() {
|
|
|
+ TenantProfile p = profile("guangming");
|
|
|
+ TenantContext.set(p);
|
|
|
+ assertEquals(p, TenantContext.current());
|
|
|
+ assertEquals("guangming", TenantContext.currentTenantId());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void clear_removes_profile() {
|
|
|
+ TenantContext.set(profile("shunfeng"));
|
|
|
+ TenantContext.clear();
|
|
|
+ assertNull(TenantContext.current());
|
|
|
+ assertNull(TenantContext.currentTenantId());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void set_overrides_previous_profile() {
|
|
|
+ TenantContext.set(profile("a"));
|
|
|
+ TenantContext.set(profile("b"));
|
|
|
+ assertEquals("b", TenantContext.currentTenantId());
|
|
|
+ }
|
|
|
+
|
|
|
+ // ---------- propagate 切换 + 恢复 ----------
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void propagate_from_null_restores_null() {
|
|
|
+ TenantProfile target = profile("akds");
|
|
|
+ AtomicReference<String> seen = new AtomicReference<>();
|
|
|
+
|
|
|
+ TenantContext.propagate(target, () -> seen.set(TenantContext.currentTenantId()));
|
|
|
+
|
|
|
+ assertEquals("akds", seen.get());
|
|
|
+ assertNull("propagate 结束后 current 必须回归 null", TenantContext.current());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void propagate_from_existing_restores_previous() {
|
|
|
+ TenantProfile prev = profile("prev");
|
|
|
+ TenantProfile target = profile("target");
|
|
|
+ TenantContext.set(prev);
|
|
|
+ AtomicReference<String> seen = new AtomicReference<>();
|
|
|
+
|
|
|
+ TenantContext.propagate(target, () -> seen.set(TenantContext.currentTenantId()));
|
|
|
+
|
|
|
+ assertEquals("target", seen.get());
|
|
|
+ assertEquals("propagate 结束后必须恢复原 profile", prev, TenantContext.current());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void propagate_restores_even_when_runnable_throws() {
|
|
|
+ TenantProfile prev = profile("prev");
|
|
|
+ TenantContext.set(prev);
|
|
|
+
|
|
|
+ try {
|
|
|
+ TenantContext.propagate(profile("crash"), () -> {
|
|
|
+ throw new RuntimeException("boom");
|
|
|
+ });
|
|
|
+ fail("runnable 异常未传播");
|
|
|
+ } catch (RuntimeException expected) {
|
|
|
+ assertEquals("boom", expected.getMessage());
|
|
|
+ }
|
|
|
+
|
|
|
+ assertEquals("runnable 抛异常后 finally 必须恢复 prev", prev, TenantContext.current());
|
|
|
+ }
|
|
|
+
|
|
|
+ // ---------- 跨线程隔离 ----------
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void threadlocal_isolates_two_threads() throws InterruptedException {
|
|
|
+ TenantProfile main = profile("main");
|
|
|
+ TenantContext.set(main);
|
|
|
+
|
|
|
+ AtomicReference<String> workerSeen = new AtomicReference<>();
|
|
|
+ CountDownLatch ready = new CountDownLatch(1);
|
|
|
+ CountDownLatch done = new CountDownLatch(1);
|
|
|
+
|
|
|
+ Thread worker = new Thread(() -> {
|
|
|
+ // 子线程默认看不到主线程的 ThreadLocal
|
|
|
+ TenantProfile beforeSet = TenantContext.current();
|
|
|
+ TenantContext.set(profile("worker"));
|
|
|
+ workerSeen.set(TenantContext.currentTenantId());
|
|
|
+ ready.countDown();
|
|
|
+ try {
|
|
|
+ done.await(2, TimeUnit.SECONDS);
|
|
|
+ } catch (InterruptedException ignored) {
|
|
|
+ } finally {
|
|
|
+ TenantContext.clear();
|
|
|
+ assertNull("worker 见到的初始 context 必须为 null", beforeSet);
|
|
|
+ }
|
|
|
+ });
|
|
|
+ worker.start();
|
|
|
+ assertTrue("worker 未在 2s 内 set", ready.await(2, TimeUnit.SECONDS));
|
|
|
+
|
|
|
+ // 主线程仍是 main,未被 worker 污染
|
|
|
+ assertEquals("main", TenantContext.currentTenantId());
|
|
|
+ assertNotEquals("worker", TenantContext.currentTenantId());
|
|
|
+
|
|
|
+ done.countDown();
|
|
|
+ worker.join(2000);
|
|
|
+ }
|
|
|
+}
|