Line data Source code
1 : package middleware
2 :
3 : import (
4 : "net/http"
5 : "os"
6 : "strconv"
7 : "strings"
8 : "time"
9 :
10 : "github.com/gin-gonic/gin"
11 : )
12 :
13 : // CORSConfig は CORS ミドルウェアの動作を制御します
14 : type CORSConfig struct {
15 : AllowOrigins []string
16 : AllowMethods []string
17 : AllowHeaders []string
18 : ExposeHeaders []string
19 : AllowCredentials bool
20 : MaxAge time.Duration
21 : }
22 :
23 : // DefaultCORSConfig は CORS レスポンスヘッダの設定を保持する構造体です。
24 1 : func DefaultCORSConfig() CORSConfig {
25 1 : return CORSConfig{
26 1 : AllowOrigins: envCSV("CORS_ALLOW_ORIGINS", []string{"http://resume.local", "http://www.resume.local", "http://localhost:3000", "http://localhost:5173"}),
27 1 : AllowMethods: envCSV("CORS_ALLOW_METHODS", []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}),
28 1 : AllowHeaders: envCSV("CORS_ALLOW_HEADERS", []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Requested-With", "X-Request-ID"}),
29 1 : ExposeHeaders: envCSV("CORS_EXPOSE_HEADERS", []string{"X-Request-ID"}),
30 1 : AllowCredentials: envBool("CORS_ALLOW_CREDENTIALS", false),
31 1 : MaxAge: envDuration("CORS_MAX_AGE", 12*time.Hour),
32 1 : }
33 1 : }
34 :
35 : // Cors NewCORS returns a Gin middleware that sets Cors headers.
36 1 : func Cors(cfg CORSConfig) gin.HandlerFunc {
37 1 : allowAll := contains(cfg.AllowOrigins, "*")
38 1 : allowMethods := strings.Join(cfg.AllowMethods, ",")
39 1 : fallbackAllowHeaders := strings.Join(cfg.AllowHeaders, ",")
40 1 : expose := strings.Join(cfg.ExposeHeaders, ",")
41 1 :
42 2 : return func(c *gin.Context) {
43 1 : origin := c.Request.Header.Get("Origin")
44 2 : if origin == "" {
45 1 : c.Next()
46 1 : return
47 1 : }
48 : // 許可判定(完全一致 or "*")
49 0 : originAllowed := allowAll || containsFold(cfg.AllowOrigins, origin)
50 0 : if !originAllowed {
51 0 : c.Next()
52 0 : return
53 0 : }
54 :
55 : // ここで“常時”CORSをセット(後で上書きされても WriteHeader/Write でもう一度付ける)
56 0 : setBase := func() {
57 0 : h := c.Writer.Header()
58 0 : // withCredentials=true の場合は "*" を使えない
59 0 : if cfg.AllowCredentials || !allowAll {
60 0 : h.Set("Access-Control-Allow-Origin", origin) // 単一値
61 0 : } else {
62 0 : h.Set("Access-Control-Allow-Origin", "*")
63 0 : }
64 0 : if cfg.AllowCredentials {
65 0 : h.Set("Access-Control-Allow-Credentials", "true")
66 0 : }
67 0 : if expose != "" {
68 0 : h.Set("Access-Control-Expose-Headers", expose)
69 0 : }
70 0 : addVary(c, "Origin")
71 : }
72 0 : setBase()
73 0 :
74 0 : // プリフライトはここで完結
75 0 : if c.Request.Method == http.MethodOptions &&
76 0 : c.GetHeader("Access-Control-Request-Method") != "" {
77 0 :
78 0 : reqHdr := c.GetHeader("Access-Control-Request-Headers")
79 0 : if reqHdr == "" {
80 0 : reqHdr = fallbackAllowHeaders
81 0 : }
82 0 : h := c.Writer.Header()
83 0 : h.Set("Access-Control-Allow-Methods", allowMethods)
84 0 : h.Set("Access-Control-Allow-Headers", reqHdr)
85 0 : if cfg.MaxAge > 0 {
86 0 : h.Set("Access-Control-Max-Age", strconv.Itoa(int(cfg.MaxAge/time.Second)))
87 0 : }
88 0 : addVary(c, "Access-Control-Request-Method")
89 0 : addVary(c, "Access-Control-Request-Headers")
90 0 :
91 0 : c.AbortWithStatus(http.StatusNoContent)
92 0 : return
93 : }
94 :
95 : // --- ここがキモ:最終書き込み直前にもう一度 CORS を強制注入 ---
96 0 : origWriter := c.Writer
97 0 : c.Writer = &corsResponseWriter{
98 0 : ResponseWriter: origWriter,
99 0 : ensure: setBase,
100 0 : }
101 0 :
102 0 : c.Next()
103 : }
104 : }
105 :
106 : // gin.ResponseWriter をラップし、WriteHeader/Write の直前で ensure() を呼ぶ
107 : type corsResponseWriter struct {
108 : gin.ResponseWriter
109 : ensure func()
110 : }
111 :
112 0 : func (w *corsResponseWriter) WriteHeader(code int) {
113 0 : w.ensure()
114 0 : w.ResponseWriter.WriteHeader(code)
115 0 : }
116 0 : func (w *corsResponseWriter) Write(b []byte) (int, error) {
117 0 : // header 未送出のまま Write が先に呼ばれる場合に備えて
118 0 : w.ensure()
119 0 : return w.ResponseWriter.Write(b)
120 0 : }
121 :
122 0 : func addVary(c *gin.Context, v string) {
123 0 : const key = "Vary"
124 0 : cur := c.Writer.Header().Get(key)
125 0 : if cur == "" {
126 0 : c.Header(key, v)
127 0 : return
128 0 : }
129 : // 既に含まれていれば重複させない
130 0 : for _, part := range strings.Split(cur, ",") {
131 0 : if strings.EqualFold(strings.TrimSpace(part), v) {
132 0 : return
133 0 : }
134 : }
135 0 : c.Header(key, cur+", "+v)
136 : }
137 :
138 1 : func envCSV(key string, def []string) []string {
139 1 : raw := strings.TrimSpace(os.Getenv(key))
140 2 : if raw == "" {
141 1 : return def
142 1 : }
143 0 : parts := strings.Split(raw, ",")
144 0 : out := make([]string, 0, len(parts))
145 0 : for _, p := range parts {
146 0 : if s := strings.TrimSpace(p); s != "" {
147 0 : out = append(out, s)
148 0 : }
149 : }
150 0 : return out
151 : }
152 :
153 1 : func envBool(key string, def bool) bool {
154 1 : raw := strings.TrimSpace(strings.ToLower(os.Getenv(key)))
155 2 : if raw == "" {
156 1 : return def
157 1 : }
158 0 : switch raw {
159 0 : case "1", "true", "t", "yes", "y":
160 0 : return true
161 0 : case "0", "false", "f", "no", "n":
162 0 : return false
163 0 : default:
164 0 : return def
165 : }
166 : }
167 :
168 1 : func envDuration(key string, def time.Duration) time.Duration {
169 1 : raw := strings.TrimSpace(os.Getenv(key))
170 2 : if raw == "" {
171 1 : return def
172 1 : }
173 : // time.ParseDuration 形式(例: "2h", "30m")
174 0 : if d, err := time.ParseDuration(raw); err == nil {
175 0 : return d
176 0 : }
177 0 : return def
178 : }
179 :
180 1 : func contains(list []string, v string) bool {
181 2 : for _, s := range list {
182 1 : if s == v {
183 0 : return true
184 0 : }
185 : }
186 1 : return false
187 : }
188 :
189 0 : func containsFold(list []string, v string) bool {
190 0 : for _, s := range list {
191 0 : if strings.EqualFold(strings.TrimSpace(s), strings.TrimSpace(v)) {
192 0 : return true
193 0 : }
194 : }
195 0 : return false
196 : }
197 :
198 : //func formatSeconds(d time.Duration) string {
199 : // sec := int(d / time.Second)
200 : // return strconvItoa(sec)
201 : //}
202 :
203 : // ループ回避のため最小限
204 : //func strconvItoa(i int) string {
205 : // // 依存を減らすために簡易実装でもOKだが、通常は strconv.Itoa を使う
206 : // // ここでは標準を使います
207 : // return strconv.Itoa(i)
208 : //}
|