1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- package main
- import (
- "flag"
- "fmt"
- "io"
- "io/ioutil"
- "net/http"
- "os"
- "path"
- "time"
- "github.com/pin/tftp"
- )
// Command-line configuration: both values are filled in from flags by main
// and read by readHandler.
var (
	// url is the base HTTP URL that requests for files missing from dir are
	// proxied to.
	url string
	// dir is the local directory that files are served from.
	dir string
)
- // readHandler is called when client starts file download from server
- func readHandler(filename string, rf io.ReaderFrom) error {
- if _, err := os.Stat(path.Join(dir, filename)); err == nil {
- file, err := os.Open(path.Join(dir, filename))
- if err != nil {
- fmt.Fprintf(os.Stderr, "%v\n", err)
- return err
- }
- fi, err := file.Stat()
- if err != nil {
- fmt.Fprintf(os.Stderr, "%v\n", err)
- return err
- }
- rf.(tftp.OutgoingTransfer).SetSize(fi.Size())
- n, err := rf.ReadFrom(file)
- if err != nil {
- fmt.Fprintf(os.Stderr, "%v\n", err)
- return err
- }
- fmt.Printf("%s %d bytes sent\n", filename, n)
- } else { // File not found locally. Proxying the request.
- fileUrl := url + "/" + filename
- resp, err := http.Get(fileUrl)
- if err != nil {
- fmt.Fprintf(os.Stderr, "%v\n", err)
- return err
- }
- defer resp.Body.Close()
- if resp.StatusCode != 200 {
- io.Copy(ioutil.Discard, resp.Body)
- return fmt.Errorf("Received status code: %d", resp.StatusCode)
- }
- rf.(tftp.OutgoingTransfer).SetSize(resp.ContentLength)
- n, err := rf.ReadFrom(resp.Body)
- if err != nil {
- fmt.Fprintf(os.Stderr, "%v\n", err)
- return err
- }
- fmt.Printf("%s %d bytes sent\n", filename, n)
- }
- return nil
- }
- func main() {
- flag.StringVar(&dir, "dir", "/var/lib/tftpboot", "The directory to serve files from. For example /var/lib/tftpboot")
- flag.StringVar(&url, "url", "http://example.com", "The URL to proxy requests to. For example http://example.com")
- flag.Parse()
- // use nil in place of handler to disable read or write operations
- s := tftp.NewServer(readHandler, nil)
- s.SetTimeout(5 * time.Second) // optional
- err := s.ListenAndServe(":69") // blocks until s.Shutdown() is called
- if err != nil {
- fmt.Fprintf(os.Stdout, "server: %v\n", err)
- os.Exit(1)
- }
- }
|