smtp_test.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. package gomail
  2. import (
  3. "bytes"
  4. "crypto/tls"
  5. "io"
  6. "net"
  7. "net/smtp"
  8. "reflect"
  9. "testing"
  10. "time"
  11. )
  12. const (
  13. testPort = 587
  14. testSSLPort = 465
  15. )
  16. var (
  17. testConn = &net.TCPConn{}
  18. testTLSConn = &tls.Conn{}
  19. testConfig = &tls.Config{InsecureSkipVerify: true}
  20. testAuth = smtp.PlainAuth("", testUser, testPwd, testHost)
  21. )
  22. func TestDialer(t *testing.T) {
  23. d := NewDialer(testHost, testPort, "user", "pwd")
  24. testSendMail(t, d, []string{
  25. "Extension STARTTLS",
  26. "StartTLS",
  27. "Extension AUTH",
  28. "Auth",
  29. "Mail " + testFrom,
  30. "Rcpt " + testTo1,
  31. "Rcpt " + testTo2,
  32. "Data",
  33. "Write message",
  34. "Close writer",
  35. "Quit",
  36. "Close",
  37. })
  38. }
  39. func TestDialerSSL(t *testing.T) {
  40. d := NewDialer(testHost, testSSLPort, "user", "pwd")
  41. testSendMail(t, d, []string{
  42. "Extension AUTH",
  43. "Auth",
  44. "Mail " + testFrom,
  45. "Rcpt " + testTo1,
  46. "Rcpt " + testTo2,
  47. "Data",
  48. "Write message",
  49. "Close writer",
  50. "Quit",
  51. "Close",
  52. })
  53. }
  54. func TestDialerConfig(t *testing.T) {
  55. d := NewDialer(testHost, testPort, "user", "pwd")
  56. d.LocalName = "test"
  57. d.TLSConfig = testConfig
  58. testSendMail(t, d, []string{
  59. "Hello test",
  60. "Extension STARTTLS",
  61. "StartTLS",
  62. "Extension AUTH",
  63. "Auth",
  64. "Mail " + testFrom,
  65. "Rcpt " + testTo1,
  66. "Rcpt " + testTo2,
  67. "Data",
  68. "Write message",
  69. "Close writer",
  70. "Quit",
  71. "Close",
  72. })
  73. }
  74. func TestDialerSSLConfig(t *testing.T) {
  75. d := NewDialer(testHost, testSSLPort, "user", "pwd")
  76. d.LocalName = "test"
  77. d.TLSConfig = testConfig
  78. testSendMail(t, d, []string{
  79. "Hello test",
  80. "Extension AUTH",
  81. "Auth",
  82. "Mail " + testFrom,
  83. "Rcpt " + testTo1,
  84. "Rcpt " + testTo2,
  85. "Data",
  86. "Write message",
  87. "Close writer",
  88. "Quit",
  89. "Close",
  90. })
  91. }
  92. func TestDialerNoAuth(t *testing.T) {
  93. d := &Dialer{
  94. Host: testHost,
  95. Port: testPort,
  96. }
  97. testSendMail(t, d, []string{
  98. "Extension STARTTLS",
  99. "StartTLS",
  100. "Mail " + testFrom,
  101. "Rcpt " + testTo1,
  102. "Rcpt " + testTo2,
  103. "Data",
  104. "Write message",
  105. "Close writer",
  106. "Quit",
  107. "Close",
  108. })
  109. }
  110. func TestDialerTimeout(t *testing.T) {
  111. d := &Dialer{
  112. Host: testHost,
  113. Port: testPort,
  114. }
  115. testSendMailTimeout(t, d, []string{
  116. "Extension STARTTLS",
  117. "StartTLS",
  118. "Mail " + testFrom,
  119. "Extension STARTTLS",
  120. "StartTLS",
  121. "Mail " + testFrom,
  122. "Rcpt " + testTo1,
  123. "Rcpt " + testTo2,
  124. "Data",
  125. "Write message",
  126. "Close writer",
  127. "Quit",
  128. "Close",
  129. })
  130. }
  131. type mockClient struct {
  132. t *testing.T
  133. i int
  134. want []string
  135. addr string
  136. config *tls.Config
  137. timeout bool
  138. }
  139. func (c *mockClient) Hello(localName string) error {
  140. c.do("Hello " + localName)
  141. return nil
  142. }
  143. func (c *mockClient) Extension(ext string) (bool, string) {
  144. c.do("Extension " + ext)
  145. return true, ""
  146. }
  147. func (c *mockClient) StartTLS(config *tls.Config) error {
  148. assertConfig(c.t, config, c.config)
  149. c.do("StartTLS")
  150. return nil
  151. }
  152. func (c *mockClient) Auth(a smtp.Auth) error {
  153. if !reflect.DeepEqual(a, testAuth) {
  154. c.t.Errorf("Invalid auth, got %#v, want %#v", a, testAuth)
  155. }
  156. c.do("Auth")
  157. return nil
  158. }
  159. func (c *mockClient) Mail(from string) error {
  160. c.do("Mail " + from)
  161. if c.timeout {
  162. c.timeout = false
  163. return io.EOF
  164. }
  165. return nil
  166. }
  167. func (c *mockClient) Rcpt(to string) error {
  168. c.do("Rcpt " + to)
  169. return nil
  170. }
  171. func (c *mockClient) Data() (io.WriteCloser, error) {
  172. c.do("Data")
  173. return &mockWriter{c: c, want: testMsg}, nil
  174. }
  175. func (c *mockClient) Quit() error {
  176. c.do("Quit")
  177. return nil
  178. }
  179. func (c *mockClient) Close() error {
  180. c.do("Close")
  181. return nil
  182. }
  183. func (c *mockClient) do(cmd string) {
  184. if c.i >= len(c.want) {
  185. c.t.Fatalf("Invalid command %q", cmd)
  186. }
  187. if cmd != c.want[c.i] {
  188. c.t.Fatalf("Invalid command, got %q, want %q", cmd, c.want[c.i])
  189. }
  190. c.i++
  191. }
  192. type mockWriter struct {
  193. want string
  194. c *mockClient
  195. buf bytes.Buffer
  196. }
  197. func (w *mockWriter) Write(p []byte) (int, error) {
  198. if w.buf.Len() == 0 {
  199. w.c.do("Write message")
  200. }
  201. w.buf.Write(p)
  202. return len(p), nil
  203. }
  204. func (w *mockWriter) Close() error {
  205. compareBodies(w.c.t, w.buf.String(), w.want)
  206. w.c.do("Close writer")
  207. return nil
  208. }
  209. func testSendMail(t *testing.T, d *Dialer, want []string) {
  210. doTestSendMail(t, d, want, false)
  211. }
  212. func testSendMailTimeout(t *testing.T, d *Dialer, want []string) {
  213. doTestSendMail(t, d, want, true)
  214. }
  215. func doTestSendMail(t *testing.T, d *Dialer, want []string, timeout bool) {
  216. testClient := &mockClient{
  217. t: t,
  218. want: want,
  219. addr: addr(d.Host, d.Port),
  220. config: d.TLSConfig,
  221. timeout: timeout,
  222. }
  223. netDialTimeout = func(network, address string, d time.Duration) (net.Conn, error) {
  224. if network != "tcp" {
  225. t.Errorf("Invalid network, got %q, want tcp", network)
  226. }
  227. if address != testClient.addr {
  228. t.Errorf("Invalid address, got %q, want %q",
  229. address, testClient.addr)
  230. }
  231. return testConn, nil
  232. }
  233. tlsClient = func(conn net.Conn, config *tls.Config) *tls.Conn {
  234. if conn != testConn {
  235. t.Errorf("Invalid conn, got %#v, want %#v", conn, testConn)
  236. }
  237. assertConfig(t, config, testClient.config)
  238. return testTLSConn
  239. }
  240. smtpNewClient = func(conn net.Conn, host string) (smtpClient, error) {
  241. if host != testHost {
  242. t.Errorf("Invalid host, got %q, want %q", host, testHost)
  243. }
  244. return testClient, nil
  245. }
  246. if err := d.DialAndSend(getTestMessage()); err != nil {
  247. t.Error(err)
  248. }
  249. }
  250. func assertConfig(t *testing.T, got, want *tls.Config) {
  251. if want == nil {
  252. want = &tls.Config{ServerName: testHost}
  253. }
  254. if got.ServerName != want.ServerName {
  255. t.Errorf("Invalid field ServerName in config, got %q, want %q", got.ServerName, want.ServerName)
  256. }
  257. if got.InsecureSkipVerify != want.InsecureSkipVerify {
  258. t.Errorf("Invalid field InsecureSkipVerify in config, got %v, want %v", got.InsecureSkipVerify, want.InsecureSkipVerify)
  259. }
  260. }