main.go 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. package main
  2. import (
  3. "flag"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "net/http"
  8. "os"
  9. "path"
  10. "time"
  11. "github.com/pin/tftp"
  12. )
  13. var url string
  14. var dir string
  15. // readHandler is called when client starts file download from server
  16. func readHandler(filename string, rf io.ReaderFrom) error {
  17. if _, err := os.Stat(path.Join(dir, filename)); err == nil {
  18. file, err := os.Open(path.Join(dir, filename))
  19. if err != nil {
  20. fmt.Fprintf(os.Stderr, "%v\n", err)
  21. return err
  22. }
  23. fi, err := file.Stat()
  24. if err != nil {
  25. fmt.Fprintf(os.Stderr, "%v\n", err)
  26. return err
  27. }
  28. rf.(tftp.OutgoingTransfer).SetSize(fi.Size())
  29. n, err := rf.ReadFrom(file)
  30. if err != nil {
  31. fmt.Fprintf(os.Stderr, "%v\n", err)
  32. return err
  33. }
  34. fmt.Printf("%s %d bytes sent\n", filename, n)
  35. } else { // File not found locally. Proxying the request.
  36. fileUrl := url + "/" + filename
  37. resp, err := http.Get(fileUrl)
  38. if err != nil {
  39. fmt.Fprintf(os.Stderr, "%v\n", err)
  40. return err
  41. }
  42. defer resp.Body.Close()
  43. if resp.StatusCode != 200 {
  44. io.Copy(ioutil.Discard, resp.Body)
  45. return fmt.Errorf("Received status code: %d", resp.StatusCode)
  46. }
  47. rf.(tftp.OutgoingTransfer).SetSize(resp.ContentLength)
  48. n, err := rf.ReadFrom(resp.Body)
  49. if err != nil {
  50. fmt.Fprintf(os.Stderr, "%v\n", err)
  51. return err
  52. }
  53. fmt.Printf("%s %d bytes sent\n", filename, n)
  54. }
  55. return nil
  56. }
  57. func main() {
  58. flag.StringVar(&dir, "dir", "/var/lib/tftpboot", "The directory to serve files from. For example /var/lib/tftpboot")
  59. flag.StringVar(&url, "url", "http://example.com", "The URL to proxy requests to. For example http://example.com")
  60. flag.Parse()
  61. // use nil in place of handler to disable read or write operations
  62. s := tftp.NewServer(readHandler, nil)
  63. s.SetTimeout(5 * time.Second) // optional
  64. err := s.ListenAndServe(":69") // blocks until s.Shutdown() is called
  65. if err != nil {
  66. fmt.Fprintf(os.Stdout, "server: %v\n", err)
  67. os.Exit(1)
  68. }
  69. }