1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "context"
12 "crypto/cipher"
13 "crypto/subtle"
14 "crypto/x509"
15 "errors"
16 "fmt"
17 "hash"
18 "io"
19 "net"
20 "sync"
21 "sync/atomic"
22 "time"
23 )
24
25
26
// A Conn is a secured connection layered on top of an underlying net.Conn.
// The address and deadline methods in this file delegate directly to conn.
type Conn struct {
	// Not reassigned anywhere in this part of the file.
	conn        net.Conn                    // underlying transport; all raw reads and writes go through it
	isClient    bool                        // consulted by handleRenegotiation and VerifyHostname
	handshakeFn func(context.Context) error // invoked by handshakeContext to run the handshake

	// handshakeStatus is read with atomic.LoadUint32 by handshakeComplete,
	// which reports true when the value is 1. handleRenegotiation stores 0
	// before it starts a new client handshake.
	handshakeStatus uint32

	// handshakeMutex is held by handshakeContext, handleRenegotiation,
	// ConnectionState, OCSPResponse and VerifyHostname.
	handshakeMutex sync.Mutex
	handshakeErr   error   // result of the most recent handshake attempt; returned by later handshake calls
	vers           uint16  // protocol version, compared against the VersionTLS1x constants
	haveVers       bool    // once true, readRecordOrCCS enforces the version carried in record headers
	config         *Config // configuration consulted for record sizing, randomness and renegotiation policy

	// handshakes is incremented after every successful handshake
	// (handshakeContext and handleRenegotiation) and is checked against the
	// RenegotiateOnceAsClient policy.
	handshakes       int
	didResume        bool                // reported as ConnectionState.DidResume
	cipherSuite      uint16              // reported as ConnectionState.CipherSuite; looked up by handleKeyUpdate
	ocspResponse     []byte              // returned by OCSPResponse and ConnectionState
	scts             [][]byte            // reported as ConnectionState.SignedCertificateTimestamps
	peerCertificates []*x509.Certificate // reported as ConnectionState.PeerCertificates

	// verifiedChains is reported as ConnectionState.VerifiedChains;
	// VerifyHostname refuses to run when it is empty.
	verifiedChains [][]*x509.Certificate

	// serverName is reported as ConnectionState.ServerName.
	serverName string

	// NOTE(review): secureRenegotiation is not referenced in this part of the
	// file; presumably it records the peer's renegotiation-extension state —
	// confirm against the handshake code.
	secureRenegotiation bool

	// ekm is exposed through ConnectionState only when
	// config.Renegotiation == RenegotiateNever.
	ekm func(label string, context []byte, length int) ([]byte, error)

	// NOTE(review): resumptionSecret is not referenced in this part of the
	// file; presumably used for session-ticket handling — confirm.
	resumptionSecret []byte

	// NOTE(review): ticketKeys is not referenced in this part of the file;
	// presumably the session-ticket keys for this connection — confirm.
	ticketKeys []ticketKey

	// clientFinishedIsFirst selects whether clientFinished (true) or
	// serverFinished (false) is reported as ConnectionState.TLSUnique.
	clientFinishedIsFirst bool

	// closeNotifyErr is the result of the sendAlertLocked(alertCloseNotify)
	// call made by closeNotify; it is returned by every later closeNotify call.
	closeNotifyErr error

	// closeNotifySent is set by closeNotify after the alert write was
	// attempted; Write fails with errShutdown once it is true.
	closeNotifySent bool

	// clientFinished and serverFinished back ConnectionState.TLSUnique
	// (see connectionStateLocked).
	clientFinished [12]byte
	serverFinished [12]byte

	// clientProtocol is reported as ConnectionState.NegotiatedProtocol.
	clientProtocol string

	// Input/output state.
	in, out   halfConn     // per-direction record protection state; in is locked by Read, out by Write
	rawInput  bytes.Buffer // bytes read from conn by readFromUntil, starting at a record header
	input     bytes.Reader // decrypted application data of the current record (shares rawInput's memory)
	hand      bytes.Buffer // handshake record payloads waiting for readHandshake
	buffering bool         // when true, write appends to sendBuf instead of writing to conn
	sendBuf   []byte       // records queued by write, sent to conn by flush

	// bytesSent is increased by the byte count of every conn.Write performed
	// by write and flush. packetsSent is increased by
	// maxPayloadSizeForWrite each time it computes a dynamic record size.
	bytesSent   int64
	packetsSent int64

	// retryCount counts ignored records and post-handshake messages; it is
	// incremented by retryReadRecord and handlePostHandshakeMessage and reset
	// by readRecordOrCCS when a non-empty record other than an alert or
	// ChangeCipherSpec arrives. Exceeding maxUselessRecords is fatal.
	retryCount int

	// activeCall is accessed atomically. The low bit is set by Close; Write
	// adds 2 while a call is in flight and subtracts 2 when it returns.
	activeCall int32

	// tmp is scratch space; sendAlertLocked uses its first two bytes to build
	// the alert record payload.
	tmp [16]byte
}
119
120
121
122
123
124
125 func (c *Conn) LocalAddr() net.Addr {
126 return c.conn.LocalAddr()
127 }
128
129
130 func (c *Conn) RemoteAddr() net.Addr {
131 return c.conn.RemoteAddr()
132 }
133
134
135
136
137 func (c *Conn) SetDeadline(t time.Time) error {
138 return c.conn.SetDeadline(t)
139 }
140
141
142
143 func (c *Conn) SetReadDeadline(t time.Time) error {
144 return c.conn.SetReadDeadline(t)
145 }
146
147
148
149
150 func (c *Conn) SetWriteDeadline(t time.Time) error {
151 return c.conn.SetWriteDeadline(t)
152 }
153
154
155
// A halfConn holds the record-protection state for one direction of a
// connection. Conn keeps one for reading (in) and one for writing (out).
type halfConn struct {
	sync.Mutex

	err     error       // sticky error recorded by setErrorLocked
	version uint16      // protocol version, set by prepareCipherSpec
	cipher  interface{} // active cipher: cipher.Stream, aead or cbcMode; nil means records pass through unprotected
	mac     hash.Hash   // active MAC, handed to tls10MAC for stream and CBC ciphers
	seq     [8]byte     // sequence number, most significant byte first (see incSeq)

	scratchBuf [13]byte // reusable backing storage for MAC input and AEAD additional data

	nextCipher interface{} // staged by prepareCipherSpec, activated by changeCipherSpec
	nextMac    hash.Hash   // staged by prepareCipherSpec, activated by changeCipherSpec

	trafficSecret []byte // secret most recently installed by setTrafficSecret
}
172
// permanentError wraps a net.Error and forces Temporary to report false while
// leaving the message and the timeout status of the wrapped error intact.
type permanentError struct {
	err net.Error
}

// Error returns the wrapped error's message.
func (pe *permanentError) Error() string {
	return pe.err.Error()
}

// Unwrap exposes the wrapped error.
func (pe *permanentError) Unwrap() error {
	return pe.err
}

// Timeout reports whether the wrapped error is a timeout.
func (pe *permanentError) Timeout() bool {
	return pe.err.Timeout()
}

// Temporary always reports false, whatever the wrapped error says.
func (pe *permanentError) Temporary() bool {
	return false
}
181
182 func (hc *halfConn) setErrorLocked(err error) error {
183 if e, ok := err.(net.Error); ok {
184 hc.err = &permanentError{err: e}
185 } else {
186 hc.err = err
187 }
188 return hc.err
189 }
190
191
192
193 func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) {
194 hc.version = version
195 hc.nextCipher = cipher
196 hc.nextMac = mac
197 }
198
199
200
201 func (hc *halfConn) changeCipherSpec() error {
202 if hc.nextCipher == nil || hc.version == VersionTLS13 {
203 return alertInternalError
204 }
205 hc.cipher = hc.nextCipher
206 hc.mac = hc.nextMac
207 hc.nextCipher = nil
208 hc.nextMac = nil
209 for i := range hc.seq {
210 hc.seq[i] = 0
211 }
212 return nil
213 }
214
215 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) {
216 hc.trafficSecret = secret
217 key, iv := suite.trafficKey(secret)
218 hc.cipher = suite.aead(key, iv)
219 for i := range hc.seq {
220 hc.seq[i] = 0
221 }
222 }
223
224
225 func (hc *halfConn) incSeq() {
226 for i := 7; i >= 0; i-- {
227 hc.seq[i]++
228 if hc.seq[i] != 0 {
229 return
230 }
231 }
232
233
234
235
236 panic("TLS: sequence number wraparound")
237 }
238
239
240
241
242 func (hc *halfConn) explicitNonceLen() int {
243 if hc.cipher == nil {
244 return 0
245 }
246
247 switch c := hc.cipher.(type) {
248 case cipher.Stream:
249 return 0
250 case aead:
251 return c.explicitNonceLen()
252 case cbcMode:
253
254 if hc.version >= VersionTLS11 {
255 return c.BlockSize()
256 }
257 return 0
258 default:
259 panic("unknown cipher type")
260 }
261 }
262
263
264
265
// extractPadding inspects the CBC padding at the end of payload. It returns
// the number of trailing bytes to strip (padding plus the length byte) and a
// mask that is 255 when the padding is well formed and 0 otherwise. When the
// padding is bad, toRemove is 1. An empty payload yields (0, 0).
//
// Apart from the public payload length, the checks use masks instead of
// data-dependent branches.
func extractPadding(payload []byte) (toRemove int, good byte) {
	size := len(payload)
	if size < 1 {
		return 0, 0
	}

	last := size - 1
	padByte := payload[last]

	// good starts as all ones only if the claimed padding fits in payload.
	lenDiff := uint(last) - uint(padByte)
	good = byte(int32(^lenDiff) >> 31)

	// At most 256 trailing bytes can belong to the padding (255 padding bytes
	// plus the length byte). The payload length itself is not secret.
	limit := 256
	if size < limit {
		limit = size
	}

	for off := 0; off < limit; off++ {
		// inPad is all ones while off <= padByte and zero afterwards.
		delta := uint(padByte) - uint(off)
		inPad := byte(int32(^delta) >> 31)
		cur := payload[last-off]
		good &^= inPad&padByte ^ inPad&cur
	}

	// Collapse good to 255 if every bit is still set, else to 0.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// Treat invalid padding as zero-length padding.
	padByte &= good

	toRemove = int(padByte) + 1
	return toRemove, good
}
312
// roundUp returns a rounded up to the nearest multiple of b; a value that is
// already a multiple is returned unchanged.
func roundUp(a, b int) int {
	remainder := a % b
	shortfall := (b - remainder) % b
	return a + shortfall
}
316
317
// cbcMode is a cipher.BlockMode whose IV can be replaced. decrypt and encrypt
// call SetIV with the explicit per-record IV when one is present.
type cbcMode interface {
	cipher.BlockMode
	SetIV([]byte)
}
322
323
324
// decrypt authenticates and decrypts record in place. record is a complete
// record: the recordHeaderLen-byte header followed by the payload. It returns
// the plaintext, which shares memory with record, and the effective record
// type. In TLS 1.3 that type is read from the end of the decrypted payload.
//
// Every failure is reported as an alert value (alertBadRecordMAC,
// alertUnexpectedMessage or alertRecordOverflow). The sequence number is
// advanced only on success, and not for the TLS 1.3 ChangeCipherSpec
// pass-through.
func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
	var plaintext []byte
	typ := recordType(record[0])
	payload := record[recordHeaderLen:]

	// In TLS 1.3 a ChangeCipherSpec record is handed back untouched: it is
	// not decrypted and does not consume a sequence number.
	if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
		return payload, typ, nil
	}

	// Defaults for ciphers without padding: padding counts as valid and empty.
	paddingGood := byte(255)
	paddingLen := 0

	explicitNonceLen := hc.explicitNonceLen()

	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case aead:
			if len(payload) < explicitNonceLen {
				return nil, 0, alertBadRecordMAC
			}
			// Without an explicit nonce, the sequence number serves as nonce.
			nonce := payload[:explicitNonceLen]
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}
			payload = payload[explicitNonceLen:]

			var additionalData []byte
			if hc.version == VersionTLS13 {
				// TLS 1.3 authenticates the record header as received.
				additionalData = record[:recordHeaderLen]
			} else {
				// Earlier versions authenticate seq || type || version ||
				// plaintext length, assembled in scratchBuf.
				additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
				additionalData = append(additionalData, record[:3]...)
				n := len(payload) - c.Overhead()
				additionalData = append(additionalData, byte(n>>8), byte(n))
			}

			var err error
			plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
			if err != nil {
				return nil, 0, alertBadRecordMAC
			}
		case cbcMode:
			blockSize := c.BlockSize()
			// The payload must be whole blocks and large enough to hold the
			// explicit IV, the MAC and at least one padding byte.
			minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
			if len(payload)%blockSize != 0 || len(payload) < minPayload {
				return nil, 0, alertBadRecordMAC
			}

			if explicitNonceLen > 0 {
				c.SetIV(payload[:explicitNonceLen])
				payload = payload[explicitNonceLen:]
			}
			c.CryptBlocks(payload, payload)

			// The padding verdict is not acted on here; it is folded into the
			// MAC comparison below so that bad padding and a bad MAC fail at
			// the same point.
			paddingLen, paddingGood = extractPadding(payload)
		default:
			panic("unknown cipher type")
		}

		if hc.version == VersionTLS13 {
			// Protected TLS 1.3 records always carry the application-data
			// type on the outside.
			if typ != recordTypeApplicationData {
				return nil, 0, alertUnexpectedMessage
			}
			if len(plaintext) > maxPlaintext+1 {
				return nil, 0, alertRecordOverflow
			}
			// Scan backwards over zero padding: the last non-zero byte is
			// the real content type. A plaintext consisting only of zero
			// bytes is rejected.
			for i := len(plaintext) - 1; i >= 0; i-- {
				if plaintext[i] != 0 {
					typ = recordType(plaintext[i])
					plaintext = plaintext[:i]
					break
				}
				if i == 0 {
					return nil, 0, alertUnexpectedMessage
				}
			}
		}
	} else {
		plaintext = payload
	}

	if hc.mac != nil {
		macSize := hc.mac.Size()
		if len(payload) < macSize {
			return nil, 0, alertBadRecordMAC
		}

		// n is the plaintext length once MAC and padding are removed,
		// clamped to zero without branching on its sign.
		n := len(payload) - macSize - paddingLen
		n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
		// The MAC covers the header with the plaintext length, so the
		// length field in record is rewritten before computing it.
		record[3] = byte(n >> 8)
		record[4] = byte(n)
		remoteMAC := payload[n : n+macSize]
		localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])

		// Combine the MAC comparison with the padding verdict into a single
		// check.
		macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
		if macAndPaddingGood != 1 {
			return nil, 0, alertBadRecordMAC
		}

		plaintext = payload[:n]
	}

	hc.incSeq()
	return plaintext, typ, nil
}
448
449
450
451
// sliceForAppend extends in by n bytes. head is the full extended slice and
// tail is the n-byte extension (head[len(in):]). The storage of in is reused
// when its capacity allows it; otherwise a new slice is allocated and the
// contents of in are copied over.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	oldLen := len(in)
	total := oldLen + n

	if total <= cap(in) {
		head = in[:total]
	} else {
		head = make([]byte, total)
		copy(head, in)
	}
	return head, head[oldLen:]
}
462
463
464
// encrypt protects payload under the active cipher and appends the result to
// record, which must already contain the record header. It returns the
// completed record with its length field updated. rand supplies random bytes
// for explicit IVs. With no active cipher, payload is appended unprotected
// and neither the length field nor the sequence number is touched.
//
// The only error returned comes from reading rand. It panics on an unknown
// cipher type.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
	if hc.cipher == nil {
		return append(record, payload...), nil
	}

	var explicitNonce []byte
	if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
		record, explicitNonce = sliceForAppend(record, explicitNonceLen)
		if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
			// Non-CBC ciphers with an explicit nonce shorter than 16 bytes
			// use the sequence number as the nonce instead of random bytes.
			copy(explicitNonce, hc.seq[:])
		} else {
			if _, err := io.ReadFull(rand, explicitNonce); err != nil {
				return nil, err
			}
		}
	}

	var dst []byte
	switch c := hc.cipher.(type) {
	case cipher.Stream:
		// MAC the plaintext, then encrypt payload followed by the MAC.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		record, dst = sliceForAppend(record, len(payload)+len(mac))
		c.XORKeyStream(dst[:len(payload)], payload)
		c.XORKeyStream(dst[len(payload):], mac)
	case aead:
		// Without an explicit nonce, the sequence number serves as nonce.
		nonce := explicitNonce
		if len(nonce) == 0 {
			nonce = hc.seq[:]
		}

		if hc.version == VersionTLS13 {
			record = append(record, payload...)

			// The real content type is appended to the plaintext and the
			// header type is replaced by application data.
			record = append(record, record[0])
			record[0] = byte(recordTypeApplicationData)

			// The header is the additional data, so it must already carry
			// the final ciphertext length before sealing.
			n := len(payload) + 1 + c.Overhead()
			record[3] = byte(n >> 8)
			record[4] = byte(n)

			record = c.Seal(record[:recordHeaderLen],
				nonce, record[recordHeaderLen:], record[:recordHeaderLen])
		} else {
			// Additional data is seq || header, with the header still
			// carrying the plaintext length.
			additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
			additionalData = append(additionalData, record[:recordHeaderLen]...)
			record = c.Seal(record, nonce, payload, additionalData)
		}
	case cbcMode:
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		blockSize := c.BlockSize()
		plaintextLen := len(payload) + len(mac)
		// Always between 1 and blockSize bytes of padding, each holding the
		// value paddingLen-1.
		paddingLen := blockSize - plaintextLen%blockSize
		record, dst = sliceForAppend(record, plaintextLen+paddingLen)
		copy(dst, payload)
		copy(dst[len(payload):], mac)
		for i := plaintextLen; i < len(dst); i++ {
			dst[i] = byte(paddingLen - 1)
		}
		if len(explicitNonce) > 0 {
			c.SetIV(explicitNonce)
		}
		c.CryptBlocks(dst, dst)
	default:
		panic("unknown cipher type")
	}

	// Rewrite the length field to cover everything after the header.
	n := len(record) - recordHeaderLen
	record[3] = byte(n >> 8)
	record[4] = byte(n)
	hc.incSeq()

	return record, nil
}
549
550
// RecordHeaderError is the error produced by readRecordOrCCS when a record
// header is rejected (see newRecordHeaderError).
type RecordHeaderError struct {
	// Msg describes the problem; Error prefixes it with "tls: ".
	Msg string

	// RecordHeader holds up to the first five bytes of buffered input that
	// were available when the error was created.
	RecordHeader [5]byte

	// Conn is the underlying net.Conn when the first record did not look
	// like a TLS handshake. It is nil for the other record header errors
	// raised in this file.
	Conn net.Conn
}

// Error implements the error interface.
func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
565
566 func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
567 err.Msg = msg
568 err.Conn = conn
569 copy(err.RecordHeader[:], c.rawInput.Bytes())
570 return err
571 }
572
573 func (c *Conn) readRecord() error {
574 return c.readRecordOrCCS(false)
575 }
576
577 func (c *Conn) readChangeCipherSpec() error {
578 return c.readRecordOrCCS(true)
579 }
580
581
582
583
584
585
586
587
588
589
590
591
592
// readRecordOrCCS reads one record from the connection, decrypts it and
// dispatches on its type:
//
//   - application data becomes available through c.input;
//   - handshake data is appended to c.hand;
//   - a ChangeCipherSpec activates the pending input cipher when
//     expectChangeCipherSpec is true (in TLS 1.3 it is ignored and the read is
//     retried);
//   - warning alerts other than close_notify are ignored (with a retry) before
//     TLS 1.3; close_notify is recorded as io.EOF; other alerts become a
//     "remote error".
//
// Apart from temporary network errors, every failure is stored as the sticky
// input error via c.in.setErrorLocked and returned again by later calls.
// Callers in this file hold c.in (see Read and handshakeContext).
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
	if c.in.err != nil {
		return c.in.err
	}
	handshakeComplete := c.handshakeComplete()

	// c.input refers to memory owned by c.rawInput, which this function is
	// about to modify, so pending application data must be drained first.
	if c.input.Len() != 0 {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
	}
	c.input.Reset(nil)

	// Read the record header.
	if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// An EOF exactly on a record boundary (nothing buffered) is reported
		// as a plain io.EOF rather than io.ErrUnexpectedEOF.
		if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
			err = io.EOF
		}
		// Temporary network errors are returned but not made sticky.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	hdr := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(hdr[0])

	// A first byte of 0x80 before the handshake completes is treated as an
	// SSLv2 handshake, which is not supported.
	if !handshakeComplete && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
	}

	vers := uint16(hdr[1])<<8 | uint16(hdr[2])
	n := int(hdr[3])<<8 | int(hdr[4])
	// Once a version is known (and it is not TLS 1.3), every record header
	// must carry exactly that version.
	if c.haveVers && c.vers != VersionTLS13 && vers != c.vers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if !c.haveVers {
		// First record: bail out before reading the body if this does not
		// look like TLS at all (wrong record type or an implausibly high
		// version). No alert is sent; the raw connection is attached to the
		// error instead.
		if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
			return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
		}
	}
	if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	// Read the rest of the record.
	if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Consume the record from rawInput and decrypt it in place. decrypt may
	// replace typ with the inner content type.
	record := c.rawInput.Next(recordHeaderLen + n)
	data, typ, err := c.in.decrypt(record)
	if err != nil {
		return c.in.setErrorLocked(c.sendAlert(err.(alert)))
	}
	if len(data) > maxPlaintext {
		return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
	}

	// Application data is never accepted while no cipher is active.
	if c.in.cipher == nil && typ == recordTypeApplicationData {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
		// A non-empty handshake or application-data record counts as
		// progress: reset the ignored-record counter.
		c.retryCount = 0
	}

	// In TLS 1.3 a partially received handshake message must not be
	// interrupted by a record of another type.
	if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	switch typ {
	default:
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		// An alert payload is exactly two bytes: level and description.
		if len(data) != 2 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if alert(data[1]) == alertCloseNotify {
			return c.in.setErrorLocked(io.EOF)
		}
		// In TLS 1.3 every other alert is fatal, whatever its level byte.
		if c.vers == VersionTLS13 {
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		}
		switch data[0] {
		case alertLevelWarning:
			// Ignore the warning and read the next record (bounded by
			// retryReadRecord).
			return c.retryReadRecord(expectChangeCipherSpec)
		case alertLevelError:
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if len(data) != 1 || data[0] != 1 {
			return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
		}
		// A ChangeCipherSpec must not arrive in the middle of a handshake
		// message.
		if c.hand.Len() > 0 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// In TLS 1.3 the record carries no meaning here: skip it and read
		// the next record (bounded by retryReadRecord).
		if c.vers == VersionTLS13 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		if !expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if err := c.in.changeCipherSpec(); err != nil {
			return c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if !handshakeComplete || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Empty application-data records are skipped, up to the limit
		// enforced by retryReadRecord.
		if len(data) == 0 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		// data still refers to c.rawInput's memory (it came from the Next
		// call above), so no copy is made. The pending-data check at the top
		// of this function keeps rawInput from being modified before c.input
		// has been drained.
		c.input.Reset(data)

	case recordTypeHandshake:
		if len(data) == 0 || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		c.hand.Write(data)
	}

	return nil
}
753
754
755
756 func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
757 c.retryCount++
758 if c.retryCount > maxUselessRecords {
759 c.sendAlert(alertUnexpectedMessage)
760 return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
761 }
762 return c.readRecordOrCCS(expectChangeCipherSpec)
763 }
764
765
766
767
// atLeastReader reads from R and reports io.EOF once at least N bytes have
// been read in total. If R reaches io.EOF before N bytes were delivered, the
// error is io.ErrUnexpectedEOF instead. N is the number of bytes still
// outstanding and may go negative when a read returns more than was needed.
type atLeastReader struct {
	R io.Reader
	N int64
}

// Read implements io.Reader with the semantics described on atLeastReader.
func (r *atLeastReader) Read(p []byte) (int, error) {
	if r.N <= 0 {
		return 0, io.EOF
	}

	n, err := r.R.Read(p)
	r.N -= int64(n)

	switch {
	case r.N > 0 && err == io.EOF:
		// The source ended before the required amount arrived.
		return n, io.ErrUnexpectedEOF
	case r.N <= 0 && err == nil:
		// Requirement met: signal completion to the caller.
		return n, io.EOF
	}
	return n, err
}
787
788
789
790 func (c *Conn) readFromUntil(r io.Reader, n int) error {
791 if c.rawInput.Len() >= n {
792 return nil
793 }
794 needs := n - c.rawInput.Len()
795
796
797
798 c.rawInput.Grow(needs + bytes.MinRead)
799 _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
800 return err
801 }
802
803
804 func (c *Conn) sendAlertLocked(err alert) error {
805 switch err {
806 case alertNoRenegotiation, alertCloseNotify:
807 c.tmp[0] = alertLevelWarning
808 default:
809 c.tmp[0] = alertLevelError
810 }
811 c.tmp[1] = byte(err)
812
813 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
814 if err == alertCloseNotify {
815
816 return writeErr
817 }
818
819 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
820 }
821
822
823 func (c *Conn) sendAlert(err alert) error {
824 c.out.Lock()
825 defer c.out.Unlock()
826 return c.sendAlertLocked(err)
827 }
828
const (
	// tcpMSSEstimate is the assumed number of bytes that fit in one TCP
	// segment. maxPayloadSizeForWrite subtracts the record overhead from it
	// to size the first application-data records.
	// NOTE(review): the derivation of 1208 is not visible here; presumably a
	// conservative MSS figure — confirm.
	tcpMSSEstimate = 1208

	// recordSizeBoostThreshold is the value of c.bytesSent at or above which
	// maxPayloadSizeForWrite stops limiting the record size and always
	// returns maxPlaintext.
	recordSizeBoostThreshold = 128 * 1024
)
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859 func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
860 if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
861 return maxPlaintext
862 }
863
864 if c.bytesSent >= recordSizeBoostThreshold {
865 return maxPlaintext
866 }
867
868
869 payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
870 if c.out.cipher != nil {
871 switch ciph := c.out.cipher.(type) {
872 case cipher.Stream:
873 payloadBytes -= c.out.mac.Size()
874 case cipher.AEAD:
875 payloadBytes -= ciph.Overhead()
876 case cbcMode:
877 blockSize := ciph.BlockSize()
878
879
880 payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
881
882
883 payloadBytes -= c.out.mac.Size()
884 default:
885 panic("unknown cipher type")
886 }
887 }
888 if c.vers == VersionTLS13 {
889 payloadBytes--
890 }
891
892
893 pkt := c.packetsSent
894 c.packetsSent++
895 if pkt > 1000 {
896 return maxPlaintext
897 }
898
899 n := payloadBytes * int(pkt+1)
900 if n > maxPlaintext {
901 n = maxPlaintext
902 }
903 return n
904 }
905
906 func (c *Conn) write(data []byte) (int, error) {
907 if c.buffering {
908 c.sendBuf = append(c.sendBuf, data...)
909 return len(data), nil
910 }
911
912 n, err := c.conn.Write(data)
913 c.bytesSent += int64(n)
914 return n, err
915 }
916
917 func (c *Conn) flush() (int, error) {
918 if len(c.sendBuf) == 0 {
919 return 0, nil
920 }
921
922 n, err := c.conn.Write(c.sendBuf)
923 c.bytesSent += int64(n)
924 c.sendBuf = nil
925 c.buffering = false
926 return n, err
927 }
928
929
// outBufPool recycles the buffers in which writeRecordLocked assembles
// outgoing records. It stores *[]byte; a fresh entry points to a nil slice.
var outBufPool = sync.Pool{
	New: func() interface{} {
		return new([]byte)
	},
}
935
936
937
938 func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
939 outBufPtr := outBufPool.Get().(*[]byte)
940 outBuf := *outBufPtr
941 defer func() {
942
943
944
945
946
947 *outBufPtr = outBuf
948 outBufPool.Put(outBufPtr)
949 }()
950
951 var n int
952 for len(data) > 0 {
953 m := len(data)
954 if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
955 m = maxPayload
956 }
957
958 _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
959 outBuf[0] = byte(typ)
960 vers := c.vers
961 if vers == 0 {
962
963
964 vers = VersionTLS10
965 } else if vers == VersionTLS13 {
966
967
968 vers = VersionTLS12
969 }
970 outBuf[1] = byte(vers >> 8)
971 outBuf[2] = byte(vers)
972 outBuf[3] = byte(m >> 8)
973 outBuf[4] = byte(m)
974
975 var err error
976 outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
977 if err != nil {
978 return n, err
979 }
980 if _, err := c.write(outBuf); err != nil {
981 return n, err
982 }
983 n += m
984 data = data[m:]
985 }
986
987 if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
988 if err := c.out.changeCipherSpec(); err != nil {
989 return n, c.sendAlertLocked(err.(alert))
990 }
991 }
992
993 return n, nil
994 }
995
996
997
998 func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
999 c.out.Lock()
1000 defer c.out.Unlock()
1001
1002 return c.writeRecordLocked(typ, data)
1003 }
1004
1005
1006
1007 func (c *Conn) readHandshake() (interface{}, error) {
1008 for c.hand.Len() < 4 {
1009 if err := c.readRecord(); err != nil {
1010 return nil, err
1011 }
1012 }
1013
1014 data := c.hand.Bytes()
1015 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1016 if n > maxHandshake {
1017 c.sendAlertLocked(alertInternalError)
1018 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
1019 }
1020 for c.hand.Len() < 4+n {
1021 if err := c.readRecord(); err != nil {
1022 return nil, err
1023 }
1024 }
1025 data = c.hand.Next(4 + n)
1026 var m handshakeMessage
1027 switch data[0] {
1028 case typeHelloRequest:
1029 m = new(helloRequestMsg)
1030 case typeClientHello:
1031 m = new(clientHelloMsg)
1032 case typeServerHello:
1033 m = new(serverHelloMsg)
1034 case typeNewSessionTicket:
1035 if c.vers == VersionTLS13 {
1036 m = new(newSessionTicketMsgTLS13)
1037 } else {
1038 m = new(newSessionTicketMsg)
1039 }
1040 case typeCertificate:
1041 if c.vers == VersionTLS13 {
1042 m = new(certificateMsgTLS13)
1043 } else {
1044 m = new(certificateMsg)
1045 }
1046 case typeCertificateRequest:
1047 if c.vers == VersionTLS13 {
1048 m = new(certificateRequestMsgTLS13)
1049 } else {
1050 m = &certificateRequestMsg{
1051 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1052 }
1053 }
1054 case typeCertificateStatus:
1055 m = new(certificateStatusMsg)
1056 case typeServerKeyExchange:
1057 m = new(serverKeyExchangeMsg)
1058 case typeServerHelloDone:
1059 m = new(serverHelloDoneMsg)
1060 case typeClientKeyExchange:
1061 m = new(clientKeyExchangeMsg)
1062 case typeCertificateVerify:
1063 m = &certificateVerifyMsg{
1064 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1065 }
1066 case typeFinished:
1067 m = new(finishedMsg)
1068 case typeEncryptedExtensions:
1069 m = new(encryptedExtensionsMsg)
1070 case typeEndOfEarlyData:
1071 m = new(endOfEarlyDataMsg)
1072 case typeKeyUpdate:
1073 m = new(keyUpdateMsg)
1074 default:
1075 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1076 }
1077
1078
1079
1080
1081 data = append([]byte(nil), data...)
1082
1083 if !m.unmarshal(data) {
1084 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1085 }
1086 return m, nil
1087 }
1088
var (
	// errShutdown is returned by Write once a close_notify alert has been
	// sent on this connection (c.closeNotifySent).
	errShutdown = errors.New("tls: protocol is shutdown")
)
1092
1093
1094
1095
1096
1097
1098
1099 func (c *Conn) Write(b []byte) (int, error) {
1100
1101 for {
1102 x := atomic.LoadInt32(&c.activeCall)
1103 if x&1 != 0 {
1104 return 0, net.ErrClosed
1105 }
1106 if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
1107 break
1108 }
1109 }
1110 defer atomic.AddInt32(&c.activeCall, -2)
1111
1112 if err := c.Handshake(); err != nil {
1113 return 0, err
1114 }
1115
1116 c.out.Lock()
1117 defer c.out.Unlock()
1118
1119 if err := c.out.err; err != nil {
1120 return 0, err
1121 }
1122
1123 if !c.handshakeComplete() {
1124 return 0, alertInternalError
1125 }
1126
1127 if c.closeNotifySent {
1128 return 0, errShutdown
1129 }
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140 var m int
1141 if len(b) > 1 && c.vers == VersionTLS10 {
1142 if _, ok := c.out.cipher.(cipher.BlockMode); ok {
1143 n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
1144 if err != nil {
1145 return n, c.out.setErrorLocked(err)
1146 }
1147 m, b = 1, b[1:]
1148 }
1149 }
1150
1151 n, err := c.writeRecordLocked(recordTypeApplicationData, b)
1152 return n + m, c.out.setErrorLocked(err)
1153 }
1154
1155
1156 func (c *Conn) handleRenegotiation() error {
1157 if c.vers == VersionTLS13 {
1158 return errors.New("tls: internal error: unexpected renegotiation")
1159 }
1160
1161 msg, err := c.readHandshake()
1162 if err != nil {
1163 return err
1164 }
1165
1166 helloReq, ok := msg.(*helloRequestMsg)
1167 if !ok {
1168 c.sendAlert(alertUnexpectedMessage)
1169 return unexpectedMessageError(helloReq, msg)
1170 }
1171
1172 if !c.isClient {
1173 return c.sendAlert(alertNoRenegotiation)
1174 }
1175
1176 switch c.config.Renegotiation {
1177 case RenegotiateNever:
1178 return c.sendAlert(alertNoRenegotiation)
1179 case RenegotiateOnceAsClient:
1180 if c.handshakes > 1 {
1181 return c.sendAlert(alertNoRenegotiation)
1182 }
1183 case RenegotiateFreelyAsClient:
1184
1185 default:
1186 c.sendAlert(alertInternalError)
1187 return errors.New("tls: unknown Renegotiation value")
1188 }
1189
1190 c.handshakeMutex.Lock()
1191 defer c.handshakeMutex.Unlock()
1192
1193 atomic.StoreUint32(&c.handshakeStatus, 0)
1194 if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
1195 c.handshakes++
1196 }
1197 return c.handshakeErr
1198 }
1199
1200
1201
1202 func (c *Conn) handlePostHandshakeMessage() error {
1203 if c.vers != VersionTLS13 {
1204 return c.handleRenegotiation()
1205 }
1206
1207 msg, err := c.readHandshake()
1208 if err != nil {
1209 return err
1210 }
1211
1212 c.retryCount++
1213 if c.retryCount > maxUselessRecords {
1214 c.sendAlert(alertUnexpectedMessage)
1215 return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1216 }
1217
1218 switch msg := msg.(type) {
1219 case *newSessionTicketMsgTLS13:
1220 return c.handleNewSessionTicket(msg)
1221 case *keyUpdateMsg:
1222 return c.handleKeyUpdate(msg)
1223 default:
1224 c.sendAlert(alertUnexpectedMessage)
1225 return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1226 }
1227 }
1228
1229 func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
1230 cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
1231 if cipherSuite == nil {
1232 return c.in.setErrorLocked(c.sendAlert(alertInternalError))
1233 }
1234
1235 newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
1236 c.in.setTrafficSecret(cipherSuite, newSecret)
1237
1238 if keyUpdate.updateRequested {
1239 c.out.Lock()
1240 defer c.out.Unlock()
1241
1242 msg := &keyUpdateMsg{}
1243 _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal())
1244 if err != nil {
1245
1246 c.out.setErrorLocked(err)
1247 return nil
1248 }
1249
1250 newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
1251 c.out.setTrafficSecret(cipherSuite, newSecret)
1252 }
1253
1254 return nil
1255 }
1256
1257
1258
1259
1260
1261
1262
1263 func (c *Conn) Read(b []byte) (int, error) {
1264 if err := c.Handshake(); err != nil {
1265 return 0, err
1266 }
1267 if len(b) == 0 {
1268
1269
1270 return 0, nil
1271 }
1272
1273 c.in.Lock()
1274 defer c.in.Unlock()
1275
1276 for c.input.Len() == 0 {
1277 if err := c.readRecord(); err != nil {
1278 return 0, err
1279 }
1280 for c.hand.Len() > 0 {
1281 if err := c.handlePostHandshakeMessage(); err != nil {
1282 return 0, err
1283 }
1284 }
1285 }
1286
1287 n, _ := c.input.Read(b)
1288
1289
1290
1291
1292
1293
1294
1295
1296 if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
1297 recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
1298 if err := c.readRecord(); err != nil {
1299 return n, err
1300 }
1301 }
1302
1303 return n, nil
1304 }
1305
1306
1307 func (c *Conn) Close() error {
1308
1309 var x int32
1310 for {
1311 x = atomic.LoadInt32(&c.activeCall)
1312 if x&1 != 0 {
1313 return net.ErrClosed
1314 }
1315 if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) {
1316 break
1317 }
1318 }
1319 if x != 0 {
1320
1321
1322
1323
1324
1325
1326 return c.conn.Close()
1327 }
1328
1329 var alertErr error
1330 if c.handshakeComplete() {
1331 if err := c.closeNotify(); err != nil {
1332 alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
1333 }
1334 }
1335
1336 if err := c.conn.Close(); err != nil {
1337 return err
1338 }
1339 return alertErr
1340 }
1341
1342 var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1343
1344
1345
1346
1347 func (c *Conn) CloseWrite() error {
1348 if !c.handshakeComplete() {
1349 return errEarlyCloseWrite
1350 }
1351
1352 return c.closeNotify()
1353 }
1354
1355 func (c *Conn) closeNotify() error {
1356 c.out.Lock()
1357 defer c.out.Unlock()
1358
1359 if !c.closeNotifySent {
1360
1361 c.SetWriteDeadline(time.Now().Add(time.Second * 5))
1362 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1363 c.closeNotifySent = true
1364
1365 c.SetWriteDeadline(time.Now())
1366 }
1367 return c.closeNotifyErr
1368 }
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378 func (c *Conn) Handshake() error {
1379 return c.HandshakeContext(context.Background())
1380 }
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392 func (c *Conn) HandshakeContext(ctx context.Context) error {
1393
1394
1395 return c.handshakeContext(ctx)
1396 }
1397
1398 func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
1399 handshakeCtx, cancel := context.WithCancel(ctx)
1400
1401
1402
1403 defer cancel()
1404
1405
1406
1407
1408
1409
1410 if ctx.Done() != nil {
1411 done := make(chan struct{})
1412 interruptRes := make(chan error, 1)
1413 defer func() {
1414 close(done)
1415 if ctxErr := <-interruptRes; ctxErr != nil {
1416
1417 ret = ctxErr
1418 }
1419 }()
1420 go func() {
1421 select {
1422 case <-handshakeCtx.Done():
1423
1424 _ = c.conn.Close()
1425 interruptRes <- handshakeCtx.Err()
1426 case <-done:
1427 interruptRes <- nil
1428 }
1429 }()
1430 }
1431
1432 c.handshakeMutex.Lock()
1433 defer c.handshakeMutex.Unlock()
1434
1435 if err := c.handshakeErr; err != nil {
1436 return err
1437 }
1438 if c.handshakeComplete() {
1439 return nil
1440 }
1441
1442 c.in.Lock()
1443 defer c.in.Unlock()
1444
1445 c.handshakeErr = c.handshakeFn(handshakeCtx)
1446 if c.handshakeErr == nil {
1447 c.handshakes++
1448 } else {
1449
1450
1451 c.flush()
1452 }
1453
1454 if c.handshakeErr == nil && !c.handshakeComplete() {
1455 c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
1456 }
1457
1458 return c.handshakeErr
1459 }
1460
1461
1462 func (c *Conn) ConnectionState() ConnectionState {
1463 c.handshakeMutex.Lock()
1464 defer c.handshakeMutex.Unlock()
1465 return c.connectionStateLocked()
1466 }
1467
1468 func (c *Conn) connectionStateLocked() ConnectionState {
1469 var state ConnectionState
1470 state.HandshakeComplete = c.handshakeComplete()
1471 state.Version = c.vers
1472 state.NegotiatedProtocol = c.clientProtocol
1473 state.DidResume = c.didResume
1474 state.NegotiatedProtocolIsMutual = true
1475 state.ServerName = c.serverName
1476 state.CipherSuite = c.cipherSuite
1477 state.PeerCertificates = c.peerCertificates
1478 state.VerifiedChains = c.verifiedChains
1479 state.SignedCertificateTimestamps = c.scts
1480 state.OCSPResponse = c.ocspResponse
1481 if !c.didResume && c.vers != VersionTLS13 {
1482 if c.clientFinishedIsFirst {
1483 state.TLSUnique = c.clientFinished[:]
1484 } else {
1485 state.TLSUnique = c.serverFinished[:]
1486 }
1487 }
1488 if c.config.Renegotiation != RenegotiateNever {
1489 state.ekm = noExportedKeyingMaterial
1490 } else {
1491 state.ekm = c.ekm
1492 }
1493 return state
1494 }
1495
1496
1497
1498 func (c *Conn) OCSPResponse() []byte {
1499 c.handshakeMutex.Lock()
1500 defer c.handshakeMutex.Unlock()
1501
1502 return c.ocspResponse
1503 }
1504
1505
1506
1507
1508 func (c *Conn) VerifyHostname(host string) error {
1509 c.handshakeMutex.Lock()
1510 defer c.handshakeMutex.Unlock()
1511 if !c.isClient {
1512 return errors.New("tls: VerifyHostname called on TLS server connection")
1513 }
1514 if !c.handshakeComplete() {
1515 return errors.New("tls: handshake has not yet been performed")
1516 }
1517 if len(c.verifiedChains) == 0 {
1518 return errors.New("tls: handshake did not verify certificate chain")
1519 }
1520 return c.peerCertificates[0].VerifyHostname(host)
1521 }
1522
1523 func (c *Conn) handshakeComplete() bool {
1524 return atomic.LoadUint32(&c.handshakeStatus) == 1
1525 }
1526
View as plain text