文章

MIT 6.S081 | 0x06 Multithreading

书接上回,这次的实验是关于多线程的原理和实现,个人感觉这节很有意思,实验的实现代码没几行,不到 2 个小时就把全部任务完成了。

如果我勤快,可能会出一期关于多线程转换实现相关的文章。至于啥时候出就说不准了 :relieved:

:blue_square: Uthread: switching between threads

线程的切换和进程切换做的工作有些相似,都要保存原有线程/进程的寄存器的值,然后再载入新的寄存器的值。为了方便,我就直接把 kernel/proc.hkernel/swtch.S 这两个文件关于 context 的部分复制到 user/uthread.c 里了。

唯一的难点在于创建新线程,在创建线程的时候要对 context 进行初始化,主要有两点工作:

  1. 初始化线程栈。要将 context.sp 指向栈顶,这里注意栈是从高位空间往低位空间增长的
  2. 函数入口。这里我迷惑了一下,不知道怎么执行线程对应的函数。略微检索了一下,其实将函数入口地址存到寄存器 ra 就可以了,在线程执行过程中就会自动跳转到对应的函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
diff --git a/user/uthread.c b/user/uthread.c
index 06349f5..1db73ef 100644
--- a/user/uthread.c
+++ b/user/uthread.c
@@ -10,10 +10,30 @@
 #define STACK_SIZE  8192
 #define MAX_THREAD  4
 
+struct context {
+  uint64 ra;
+  uint64 sp;
+
+  // callee-saved
+  uint64 s0;
+  uint64 s1;
+  uint64 s2;
+  uint64 s3;
+  uint64 s4;
+  uint64 s5;
+  uint64 s6;
+  uint64 s7;
+  uint64 s8;
+  uint64 s9;
+  uint64 s10;
+  uint64 s11;
+};
 
 struct thread {
   char       stack[STACK_SIZE]; /* the thread's stack */
   int        state;             /* FREE, RUNNING, RUNNABLE */
+
+  struct context context;
 };
 struct thread all_thread[MAX_THREAD];
 struct thread *current_thread;
@@ -62,6 +82,7 @@ thread_schedule(void)
      * Invoke thread_switch to switch from t to next_thread:
      * thread_switch(??, ??);
      */
+    thread_switch((uint64)&t->context, (uint64)&current_thread->context);
   } else
     next_thread = 0;
 }
@@ -76,6 +97,9 @@ thread_create(void (*func)())
   }
   t->state = RUNNABLE;
   // YOUR CODE HERE
+  memset(&t->context, 0, sizeof(t->context));
+  t->context.sp = (uint64)t->stack + STACK_SIZE;
+  t->context.ra = (uint64)func;
 }
 
 void 
diff --git a/user/uthread_switch.S b/user/uthread_switch.S
index 5defb12..22e6482 100644
--- a/user/uthread_switch.S
+++ b/user/uthread_switch.S
@@ -8,4 +8,34 @@
 .globl thread_switch
 thread_switch:
   /* YOUR CODE HERE */
+  sd ra, 0(a0)
+  sd sp, 8(a0)
+  sd s0, 16(a0)
+  sd s1, 24(a0)
+  sd s2, 32(a0)
+  sd s3, 40(a0)
+  sd s4, 48(a0)
+  sd s5, 56(a0)
+  sd s6, 64(a0)
+  sd s7, 72(a0)
+  sd s8, 80(a0)
+  sd s9, 88(a0)
+  sd s10, 96(a0)
+  sd s11, 104(a0)
+
+  ld ra, 0(a1)
+  ld sp, 8(a1)
+  ld s0, 16(a1)
+  ld s1, 24(a1)
+  ld s2, 32(a1)
+  ld s3, 40(a1)
+  ld s4, 48(a1)
+  ld s5, 56(a1)
+  ld s6, 64(a1)
+  ld s7, 72(a1)
+  ld s8, 80(a1)
+  ld s9, 88(a1)
+  ld s10, 96(a1)
+  ld s11, 104(a1)
+
   ret    /* return to ra */

:blue_square: Using threads

这个任务算是让我们熟悉 pthread 这个用户线程库的基本使用,非常简单,6 行代码解决。因为需要顾及时间,所以对锁的粒度有一定的要求。notxv6/ph.c 实现了一个极简的 hash table,将这个 table 分了 5 个 buckets,所以我就创建了 5 个锁,分别处理这 5 个 buckets。这样的话,两个线程可以同时处理不同的 bucket 而不占用同一把锁,相对来说在速度上会有一点提升。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
diff --git a/notxv6/ph.c b/notxv6/ph.c
index 82afe76..9d0582d 100644
--- a/notxv6/ph.c
+++ b/notxv6/ph.c
@@ -14,6 +14,7 @@ struct entry {
   struct entry *next;
 };
 struct entry *table[NBUCKET];
+pthread_mutex_t locks[NBUCKET];
 int keys[NKEYS];
 int nthread = 1;
 
@@ -47,6 +48,7 @@ void put(int key, int value)
     if (e->key == key)
       break;
   }
+  pthread_mutex_lock(&locks[i]);
   if(e){
     // update the existing key.
     e->value = value;
@@ -54,7 +56,7 @@ void put(int key, int value)
     // the new is new.
     insert(key, value, &table[i], table[i]);
   }
-
+  pthread_mutex_unlock(&locks[i]);
 }
 
 static struct entry*
@@ -118,6 +120,10 @@ main(int argc, char *argv[])
     keys[i] = random();
   }
 
+  for (int i = 0; i < NBUCKET; i++) {
+    pthread_mutex_init(&locks[i], NULL);
+  }
+
   //
   // first the puts
   //

:blue_square: Barrier

和上面一样,这个任务是让我们学习 pthread_cond_t 的使用,相关的函数用法这里不再给出,唯一要注意的一点是 pthread_cond_wait() 这个函数中,作为参数的 bstate.barrier_mutex 应该先被锁住,这在 pthread_cond_wait() 函数的注释中也有提及。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
diff --git a/notxv6/barrier.c b/notxv6/barrier.c
index 12793e8..bc94a6b 100644
--- a/notxv6/barrier.c
+++ b/notxv6/barrier.c
@@ -30,7 +30,18 @@ barrier()
   // Block until all threads have called barrier() and
   // then increment bstate.round.
   //
-  
+  pthread_mutex_lock(&bstate.barrier_mutex);
+  bstate.nthread++;
+
+  if (bstate.nthread != nthread) {
+    pthread_cond_wait(&bstate.barrier_cond, &bstate.barrier_mutex);
+  } else {
+    bstate.round++;
+    bstate.nthread = 0;
+    pthread_cond_broadcast(&bstate.barrier_cond);
+  }
+  pthread_mutex_unlock(&bstate.barrier_mutex);
+
 }
 
 static void *
本文由作者按照 CC BY 4.0 进行授权