1.context是什么

go1.7才引入context,译作“上下文”,实际也叫goroutine 的上下文,包含 goroutine 的运行状态、环境、现场等信息、context 主要用来在 goroutine 之间传递上下文信息,包括:取消信号、超时时间、截止时间、k-v 等。与WaitGroup最大的不同点是context对于派生goroutine有更强的控制力,它可以控制多级的goroutine

随着 context 包的引入,标准库中很多接口加上了 context 参数,例如 database/sql 包、http包。context 几乎成为了并发控制和超时控制的标准做法,由于context的源码里用到了大量的mutex锁用于保护子级的context,且采用链式调用的方式,所以它是并发安全的

2.context接口的实现

context接口只定义了4种方法

type Context interface {
    Deadline() (deadline time.Time, ok bool)
    Done() <-chan struct{}
    Err() error
    Value(key interface{}) interface{}
}
  • Deadline 返回此上下文完成的工作的截止时间,未设置截止日期时,返回 ok==false,对 Deadline 的连续调用返回相同的结果

  • Done 返回一个channel,可以表示 context 被取消的信号,这是一个只读的channel,当这个 channel 被关闭时,说明 context 被取消了,而且读一个关闭的 channel 会读出相应类型的零值(channel对应的零值是nil)。常用在select-case语句中,如case <-context.Done():

  • Err 描述context关闭的原因,由context实现控制,不需要用户设置,例如是被取消,还是超时,主动取消的就返回context canceled,因超时关闭就返回context deadline exceeded

  • Value 用于在树状分布的goroutine间传递信息,根据key值查询map中的value

3.实现context接口的几种结构体

整体类图

3.1 emptyCtx

type emptyCtx int

context包中定义了一个空的context, 名为emptyCtx,用于context的根节点,空的context只是简单的实现了Context,本身不包含任何值,仅用于其他context的父节点

func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
	return
}
func (*emptyCtx) Done() <-chan struct{} {
	return nil
}
func (*emptyCtx) Err() error {
	return nil
}
func (*emptyCtx) Value(key interface{}) interface{} {
	return nil
}
func (e *emptyCtx) String() string {
	switch e {
	case background:
		return "context.Background"
	case todo:
		return "context.TODO"
	}
	return "unknown empty Context"
}
var (
	background = new(emptyCtx)
	todo       = new(emptyCtx)
)
func Background() Context {
	return background
}
func TODO() Context {
	return todo
}

emptyCtx是一个int类型的变量,但实现了context的接口。emptyCtx没有超时时间,不能取消,也不能存储任何额外信息,所以emptyCtx用来作为context树的根节点

  • background 通常用在 main 函数中,作为所有 context 的根节点

  • todo 通常用在并不知道传递什么context的情形,相当于用 todo 占个位子,最终要换成其他 context

3.2 cancelCtx

这是一个可以取消的context

type canceler interface {
	cancel(removeFromParent bool, err error)
	Done() <-chan struct{}
}
type cancelCtx struct {
	Context
	mu       sync.Mutex           
	done     atomic.Value         
	children map[canceler]struct{}
	err      error                
}

cancelCtx将接口 Context 作为它的一个匿名字段,因此可以被看成是一个 Context,同时cancelCtx实现了 canceler 接口。children中记录了由此context派生的所有child,此context被cancel时会把其中的所有child都cancel掉,cancelCtx与deadline和value无关

func (c *cancelCtx) Done() <-chan struct{} {
	d := c.done.Load()
	if d != nil {
		return d.(chan struct{})
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	d = c.done.Load()
	if d == nil {
		d = make(chan struct{})
		c.done.Store(d)
	}
	return d.(chan struct{})
}
func (c *cancelCtx) Err() error {
	c.mu.Lock()
	err := c.err
	c.mu.Unlock()
	return err
}
func (c *cancelCtx) cancel(removeFromParent bool, err error) {
	if err == nil {
		panic("context: internal error: missing cancel error")
	}
	c.mu.Lock()
	if c.err != nil {
		c.mu.Unlock()
		return // already canceled
	}
	c.err = err
	d, _ := c.done.Load().(chan struct{})
	if d == nil {
		c.done.Store(closedchan)
	} else {
		close(d)
	}
	for child := range c.children {
		// NOTE: acquiring the child's lock while holding parent's lock.
		child.cancel(false, err)
	}
	c.children = nil
	c.mu.Unlock()

	if removeFromParent {
		removeChild(c.Context, c)
	}
}
  • Done 返回一个只读的channel,而且没有地方向这个 channel 里面写数据,所以直接读这个channel协程会被阻塞住,一般通过搭配 select 来使用,一旦关闭,就会立即读出对应类型的零值

  • Err 则是返回对应的错误类型

  • cancel 关闭 c.done,取消 c 的每个子节点,如果 removeFromParent 为真,则从其父节点的子节点中删除 c,总体来说就是删除自己和其后代,不会影响到父节点和其它分支的节点,这里删除子节点调用的是removeChild,可见调用了delete

    func removeChild(parent Context, child canceler) {
    	p, ok := parentCancelCtx(parent)
    	if !ok {
    		return
    	}
    	p.mu.Lock()
    	if p.children != nil {
    		delete(p.children, child)
    	}
    	p.mu.Unlock()
    }
    

有一个WithCancel方法,会暴露给写代码的人调用:

func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
	if parent == nil {
		panic("cannot create context from nil parent")
	}
	c := newCancelCtx(parent) //①
	propagateCancel(parent, &c) //②
	return &c, func() { c.cancel(true, Canceled) } //③
}

具体实现分三步:①初始化一个cancelCtx实例,②如果父节点也可以被cancel,将cancelCtx实例添加到其父节点的children中,③返回cancelCtx实例和cancel方法

第②步调用的函数propagateCancel代码逻辑可以细分为3步:

  • 如果父节点支持cancel,则父节点有children成员,可以把新context添加到children里
  • 如果父节点不支持cancel,继续向上查询直到找到一个支持cancel的节点,把新context添加到children里
  • 如果所有的父节点均不支持cancel,则启动一个协程等待父节点结束,然后再把当前context结束

同时可见context的特点是:控制是从上至下的,查找是从下至上的。

func propagateCancel(parent Context, child canceler) {
   done := parent.Done()
   if done == nil {
      return // parent is never canceled
   }

   select {
   case <-done:
      // parent is already canceled
      child.cancel(false, parent.Err())
      return
   default:
   }

   if p, ok := parentCancelCtx(parent); ok {
      p.mu.Lock()
      if p.err != nil {
         // parent has already been canceled
         child.cancel(false, p.err)
      } else {
         if p.children == nil {
            p.children = make(map[canceler]struct{})
         }
         p.children[child] = struct{}{}
      }
      p.mu.Unlock()
   } else {
      atomic.AddInt32(&goroutines, +1)
      go func() {
         select {
         case <-parent.Done():
            child.cancel(false, parent.Err())
         case <-child.Done():
         }
      }()
   }
}

3.3 timerCtx

type timerCtx struct {
	cancelCtx
	timer *time.Timer 
	deadline time.Time
}
func (c *timerCtx) Deadline() (deadline time.Time, ok bool) {
	return c.deadline, true
}
func (c *timerCtx) cancel(removeFromParent bool, err error) {
	c.cancelCtx.cancel(false, err)
	if removeFromParent {
		// Remove this timerCtx from its parent cancelCtx's children.
		removeChild(c.cancelCtx.Context, c)
	}
	c.mu.Lock()
	if c.timer != nil {
		c.timer.Stop()
		c.timer = nil
	}
	c.mu.Unlock()
}

timerCtx是基于cancelCtx的,此外多了timer和deadline,timer指最长存活时间,比如将在多少秒后结束,deadline表示最后期限,需要指定具体的截止日期,由此衍生出了WithDeadline()和WithTimeout()函数

3.4 vauleCtx

type valueCtx struct {
	Context
	key, val interface{}
}

valueCtx在Context基础上增加了一个key-value对,用于在各级协程间传递一些数据,valueCtx既不需要cancel,也不需要deadline,只需要实现Value()接口

func (c *valueCtx) Value(key interface{}) interface{} {
	if c.key == key {
		return c.val
	}
	return c.Context.Value(key)
}

因此有了WithVaule函数

func WithValue(parent Context, key, val interface{}) Context {
   if parent == nil {
      panic("cannot create context from nil parent")
   }
   if key == nil {
      panic("nil key")
   }
   if !reflectlite.TypeOf(key).Comparable() {
      panic("key is not comparable")
   }
   return &valueCtx{parent, key, val}
}

WithVaule可以用来设置键值对,且由代码return &valueCtx{parent, key, val}可知,每次调用WithValue函数都会基于当前context衍生一个新的子context,而不是在原来的context结构体上直接添加,由此形成了一条context链,获取键值的过程也是层层向上调用直到最终的根节点,若找到了key则会返回值,否则找到最终的emptyCtx返回nil


从源代码中可以看出使用了大量的锁,而且context存取值采用链式的方式,保证了执行过程中的并发安全

4.context的使用

WithCancel、WithTimeout、WithDeadline三者都返回一个可取消的 context 实例,和cancel()函数

context 的实例之间存在父子关系,当父亲取消或者超时,所有派生的子context 都被取消或者超时

当找 key 的时候,子 context 先看自己有没有,没有则去祖先里面找控制是从上至下的,查找是从下至上的

4.1 WithCancel

func ExampleWithCancel() {
	gen := func(ctx context.Context) <-chan int {
		dst := make(chan int)
		n := 1
		go func() {
			for {
				select {
				case <-ctx.Done():
					return // 接收到cancel的信号后进入这个循环并返回,防止goroutine泄漏
				case dst <- n:
					n++
				}
			}
		}()
		return dst //根据 <-chan int 可知,只能返回单向chan,否则会阻塞
	}

	ctx, cancel := context.WithCancel(context.Background())
    defer cancel() // 当break之后没有其他要执行了,就会执行cancel()

	for n := range gen(ctx) {
		fmt.Println(n)
		if n == 5 {
			break
		}
	}
	// Output:
	// 1
	// 2
	// 3
	// 4
	// 5
}

4.2 WithDeadline

func ExampleWithDeadline() {
	d := time.Now().Add(100 * time.Millisecond) //0.1秒后过期,填的是具体某个时间
	ctx, cancel := context.WithDeadline(context.Background(), d)
	defer cancel()

	select {
	case <-time.After(1 * time.Second): // 若1秒后还没过期则执行这里
		fmt.Println("overslept")
    case <-ctx.Done():	//监测cancel()
		fmt.Println(ctx.Err())
	}

	// Output:
	// context deadline exceeded
}

4.3 WithTimeout

func ExampleWithTimeout() {
	ctx, cancel := context.WithTimeout(context.Background(), 100 * time.Millisecond) //0.1秒后过期,填的是时间段长短
	defer cancel()

	select {
	case <-time.After(1 * time.Second):
		fmt.Println("overslept")
	case <-ctx.Done():
		fmt.Println(ctx.Err()) // prints "context deadline exceeded"
	}

	// Output:
	// context deadline exceeded
}

WithTimeout和WithDeadline其实非常相似,可以看下面的源码,WithTimeout直接调用WithDeadline,只是用法不一样

func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
	return WithDeadline(parent, time.Now().Add(timeout))
}

就控制来说,父context可以控制子context,如下例:

func TestControl(t *testing.T) {
	ctx1, cancelFunc1 := context.WithTimeout(context.Background(), time.Second)
	defer cancelFunc1()
	ctx2, cancelFunc2 := context.WithTimeout(ctx1, 3*time.Second)
	defer cancelFunc2()
	go func() {
		t1 := time.Now().Second()
		<-ctx2.Done()
		fmt.Println("timeout")
		fmt.Println(time.Now().Second()-t1)
	}()
	time.Sleep(2 * time.Second)
    
    // Output
    // timeout
	// 1
    
    //由结果可知前后相差了一秒,说明父context时间到期后,就直接cancel了,不会等到子context的过期时间,
    //即父context可以控制子context,但是如果子context的时间比父context的时间更短,则会优先执行子context的cancel
}

4.4 WithValue

func TestContext(t *testing.T) {
	ctx := context.WithValue(context.Background(), "key1", "val1")	//赋值
	value := ctx.Value("key1")	//取值
	fmt.Println(value)
    
    // Output:
    // val1
}

父context无法拿到子context的结点内容,只能由子结点拿父结点的,如下源码:

func (c *valueCtx) Value(key any) any {
	if c.key == key {	//先找自己的
		return c.val
	}
	return value(c.Context, key)	//没有找到就找父节点的
}

如下例子:

func TestContext(t *testing.T) {
	ctx := context.WithValue(context.Background(), "key1", "val1")
	value := ctx.Value("key1")
	fmt.Println(value)

	subCtx := context.WithValue(ctx, "", "")
	subValue := subCtx.Value("key1")
	fmt.Println(subValue)
    
    // Output:
    // val1
    // val1		
}

4.5 异步调用链

下面这个例子的流程图如下

假设一个例子,genGreeting在放弃调用locale之前等待一秒——超时时间为1秒,如果printGreeting不成功,就取消对printFare的调用

可以看到系统输出工作正常,由于local设置至少需要运行一分钟,因此genGreeting将始终超时,这意味着main会始终取消printFarewell下面的调用链,

func main() {
    var wg sync.WaitGroup
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()
    wg.Add(1)
    go func() {
        defer wg.Done()

        if err := printGreeting(ctx); err != nil {
            fmt.Printf("cannot print greeting: %v\n", err)
            cancel()
        }
    }()

    wg.Add(1)
    go func() {
        defer wg.Done()
        if err := printFarewell(ctx); err != nil {
            fmt.Printf("cannot print farewell: %v\n", err)
        }
    }()

    wg.Wait()
}

func printGreeting(ctx context.Context) error {
    greeting, err := genGreeting(ctx)
    if err != nil {
        return err
    }
    fmt.Printf("%s world!\n", greeting)
    return nil
}

func printFarewell(ctx context.Context) error {
    farewell, err := genFarewell(ctx)
    if err != nil {
        return err
    }
    fmt.Printf("%s world!\n", farewell)
    return nil
}

func genGreeting(ctx context.Context) (string, error) {
    ctx, cancel := context.WithTimeout(ctx, 1*time.Second) //只等1秒钟就取消
    defer cancel()

    switch locale, err := locale(ctx); {
    case err != nil:
        return "", err
    case locale == "EN/US":
        return "hello", nil
    }
    return "", fmt.Errorf("unsupported locale")
}

func genFarewell(ctx context.Context) (string, error) {
    switch locale, err := locale(ctx); {
    case err != nil:
        return "", err
    case locale == "EN/US":
        return "goodbye", nil
    }
    return "", fmt.Errorf("unsupported locale")
}

func locale(ctx context.Context) (string, error) {
    select {
    case <-ctx.Done():
        return "", ctx.Err() 
    case <-time.After(1 * time.Minute):	//等待一分钟后执行
    }
    return "EN/US", nil
}

//结果
cannot print greeting: context deadline exceeded
cannot print farewell: context canceled

可以在这个程序上进一步改进:因为已知locale需要大约一分钟的时间才能运行,所以可以在locale中检查是否给出了deadline,如果给出了,则返回一个context包预设的错误——DeadlineExceeded

可以看到最终结果是一样的,但是会马上得出执行结果,而不会被阻塞1秒钟

func main() {
    var wg sync.WaitGroup
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    wg.Add(1)
    go func() {
        defer wg.Done()

        if err := printGreeting(ctx); err != nil {
            fmt.Printf("cannot print greeting: %v\n", err)
            cancel()
        }
    }()
    wg.Add(1)
    go func() {
        defer wg.Done()
        if err := printFarewell(ctx); err != nil {
            fmt.Printf("cannot print farewell: %v\n", err)
        }
    }()

    wg.Wait()
}

func printGreeting(ctx context.Context) error {
    greeting, err := genGreeting(ctx)
    if err != nil {
        return err
    }
    fmt.Printf("%s world!\n", greeting)
    return nil
}

func printFarewell(ctx context.Context) error {
    farewell, err := genFarewell(ctx)
    if err != nil {
        return err
    }
    fmt.Printf("%s world!\n", farewell)
    return nil
}

func genGreeting(ctx context.Context) (string, error) {
    ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
    defer cancel()

    switch locale, err := locale(ctx); {
    case err != nil:
        return "", err
    case locale == "EN/US":
        return "hello", nil
    }
    return "", fmt.Errorf("unsupported locale")
}

func genFarewell(ctx context.Context) (string, error) {
    switch locale, err := locale(ctx); {
    case err != nil:
        return "", err
    case locale == "EN/US":
        return "goodbye", nil
    }
    return "", fmt.Errorf("unsupported locale")
}

func locale(ctx context.Context) (string, error) {
    if deadline, ok := ctx.Deadline(); ok { //1
        if deadline.Sub(time.Now().Add(1*time.Minute)) <= 0 {
            return "", context.DeadlineExceeded
        }
    }

    select {
    case <-ctx.Done():
        return "", ctx.Err()
    case <-time.After(1 * time.Minute):
    }
    return "EN/US", nil
}

//结果
cannot print greeting: context deadline exceeded
cannot print farewell: context canceled

4.6 协程取消信号同步

在并发程序中,由于超时、取消操作或者一些异常情况,往往需要进行抢占操作或者中断后续操作,如下例可以采用channel的方式控制,这里采用主协程main控制通道的关闭,子协程监听done,一旦主协程关闭了channel,那么子协程就可以退出了

因为这个例子还不复杂,所以用通道控制感觉还可以,但是当有多个主协程和多个子协程时,就要定义多个done channel,这将变得非常混乱

func main() {
   messages := make(chan int, 10)
   done := make(chan bool)

   defer close(messages)
   // consumer
   go func() {
      ticker := time.NewTicker(1 * time.Second)
      for _ = range ticker.C {
         select {
         case <-done:
            fmt.Println("child process interrupt...")
            return
         default:
            fmt.Printf("send message: %d\n", <-messages)
         }
      }
   }()

   // producer
   for i := 0; i < 10; i++ {
      messages <- i
   }
   time.Sleep(5 * time.Second)
   close(done)
   time.Sleep(1 * time.Second)
   fmt.Println("main process exit!")
}

//结果
send message: 0
send message: 1
send message: 2
send message: 3
send message: 4
child process interrupt...
main process exit!

可以试试用context改写解决这个问题,如下例,只要让子线程监听主线程传入的ctx,一旦ctx.Done()返回空channel,子线程即可取消执行任务

func main() {
	messages := make(chan int, 10)

	// producer
	for i := 0; i < 10; i++ {
		messages <- i
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)

	// consumer
	go func(ctx context.Context) {
		ticker := time.NewTicker(1 * time.Second)
		for _ = range ticker.C {
			select {
			case <-ctx.Done():
				fmt.Println("child process interrupt...")
				return
			default:
				fmt.Printf("send message: %d\n", <-messages)
			}
		}
	}(ctx)

	defer close(messages)
	defer cancel()

	select {
	case <-ctx.Done():
		time.Sleep(1 * time.Second)
		fmt.Println("main process exit!")
	}
}

//结果
send message: 0
send message: 1
send message: 2
send message: 3
send message: 4
child process interrupt...
main process exit!

5. Context在源码中使用案例

5.1 sql

// conn returns a newly-opened or cached *driverConn.
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
	db.mu.Lock()
	if db.closed {
		db.mu.Unlock()
		return nil, errDBClosed
	}
	// Check if the context is expired.
	// 检查context是否超时
	select {
	default:
	case <-ctx.Done():
		db.mu.Unlock()
		return nil, ctx.Err()
	}
	
	......
}

5.2 http

type Request struct {
    ......
    
    // Response is the redirect response which caused this request
	// to be created. This field is only populated during client
	// redirects.
	Response *Response

	// ctx is either the client or server context. It should only
	// be modified via copying the whole Request using WithContext.
	// It is unexported to prevent people from using Context wrong
	// and mutating the contexts held by callers of the same request.
	ctx context.Context //context的使用
}

注意事项:不要把context用作结构体字段,除非该结构体本身也是表达一个上下文的概念

6.参考链接

https://juejin.cn/post/6844904070667321357#heading-14

https://www.topgoer.cn/docs/goquestions/goquestions-1cjh3l2qioavp

https://www.topgoer.cn/docs/gozhuanjia/chapter055.3-context

https://www.topgoer.cn/docs/concurrency/concurrency-1clit4ftnpbh3

极客时间go实战训练营