LCOV - code coverage report
Current view: top level - adapter/http/middleware - cors.go Coverage Total Hit
Test: coverage.lcov Lines: 30.1 % 136 41
Test Date: 2026-04-14 06:42:22 Functions: - 0 0

            Line data    Source code
       1              : package middleware
       2              : 
       3              : import (
       4              :         "net/http"
       5              :         "os"
       6              :         "strconv"
       7              :         "strings"
       8              :         "time"
       9              : 
      10              :         "github.com/gin-gonic/gin"
      11              : )
      12              : 
      13              : // CORSConfig は CORS ミドルウェアの動作を制御します
      14              : type CORSConfig struct {
      15              :         AllowOrigins     []string
      16              :         AllowMethods     []string
      17              :         AllowHeaders     []string
      18              :         ExposeHeaders    []string
      19              :         AllowCredentials bool
      20              :         MaxAge           time.Duration
      21              : }
      22              : 
      23              : // DefaultCORSConfig は CORS レスポンスヘッダの設定を保持する構造体です。
      24            1 : func DefaultCORSConfig() CORSConfig {
      25            1 :         return CORSConfig{
      26            1 :                 AllowOrigins:     envCSV("CORS_ALLOW_ORIGINS", []string{"http://resume.local", "http://www.resume.local", "http://localhost:3000", "http://localhost:5173"}),
      27            1 :                 AllowMethods:     envCSV("CORS_ALLOW_METHODS", []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}),
      28            1 :                 AllowHeaders:     envCSV("CORS_ALLOW_HEADERS", []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Requested-With", "X-Request-ID"}),
      29            1 :                 ExposeHeaders:    envCSV("CORS_EXPOSE_HEADERS", []string{"X-Request-ID"}),
      30            1 :                 AllowCredentials: envBool("CORS_ALLOW_CREDENTIALS", false),
      31            1 :                 MaxAge:           envDuration("CORS_MAX_AGE", 12*time.Hour),
      32            1 :         }
      33            1 : }
      34              : 
      35              : // Cors NewCORS returns a Gin middleware that sets Cors headers.
      36            1 : func Cors(cfg CORSConfig) gin.HandlerFunc {
      37            1 :         allowAll := contains(cfg.AllowOrigins, "*")
      38            1 :         allowMethods := strings.Join(cfg.AllowMethods, ",")
      39            1 :         fallbackAllowHeaders := strings.Join(cfg.AllowHeaders, ",")
      40            1 :         expose := strings.Join(cfg.ExposeHeaders, ",")
      41            1 : 
      42            2 :         return func(c *gin.Context) {
      43            1 :                 origin := c.Request.Header.Get("Origin")
      44            2 :                 if origin == "" {
      45            1 :                         c.Next()
      46            1 :                         return
      47            1 :                 }
      48              :                 // 許可判定(完全一致 or "*")
      49            0 :                 originAllowed := allowAll || containsFold(cfg.AllowOrigins, origin)
      50            0 :                 if !originAllowed {
      51            0 :                         c.Next()
      52            0 :                         return
      53            0 :                 }
      54              : 
      55              :                 // ここで“常時”CORSをセット(後で上書きされても WriteHeader/Write でもう一度付ける)
      56            0 :                 setBase := func() {
      57            0 :                         h := c.Writer.Header()
      58            0 :                         // withCredentials=true の場合は "*" を使えない
      59            0 :                         if cfg.AllowCredentials || !allowAll {
      60            0 :                                 h.Set("Access-Control-Allow-Origin", origin) // 単一値
      61            0 :                         } else {
      62            0 :                                 h.Set("Access-Control-Allow-Origin", "*")
      63            0 :                         }
      64            0 :                         if cfg.AllowCredentials {
      65            0 :                                 h.Set("Access-Control-Allow-Credentials", "true")
      66            0 :                         }
      67            0 :                         if expose != "" {
      68            0 :                                 h.Set("Access-Control-Expose-Headers", expose)
      69            0 :                         }
      70            0 :                         addVary(c, "Origin")
      71              :                 }
      72            0 :                 setBase()
      73            0 : 
      74            0 :                 // プリフライトはここで完結
      75            0 :                 if c.Request.Method == http.MethodOptions &&
      76            0 :                         c.GetHeader("Access-Control-Request-Method") != "" {
      77            0 : 
      78            0 :                         reqHdr := c.GetHeader("Access-Control-Request-Headers")
      79            0 :                         if reqHdr == "" {
      80            0 :                                 reqHdr = fallbackAllowHeaders
      81            0 :                         }
      82            0 :                         h := c.Writer.Header()
      83            0 :                         h.Set("Access-Control-Allow-Methods", allowMethods)
      84            0 :                         h.Set("Access-Control-Allow-Headers", reqHdr)
      85            0 :                         if cfg.MaxAge > 0 {
      86            0 :                                 h.Set("Access-Control-Max-Age", strconv.Itoa(int(cfg.MaxAge/time.Second)))
      87            0 :                         }
      88            0 :                         addVary(c, "Access-Control-Request-Method")
      89            0 :                         addVary(c, "Access-Control-Request-Headers")
      90            0 : 
      91            0 :                         c.AbortWithStatus(http.StatusNoContent)
      92            0 :                         return
      93              :                 }
      94              : 
      95              :                 // --- ここがキモ:最終書き込み直前にもう一度 CORS を強制注入 ---
      96            0 :                 origWriter := c.Writer
      97            0 :                 c.Writer = &corsResponseWriter{
      98            0 :                         ResponseWriter: origWriter,
      99            0 :                         ensure:         setBase,
     100            0 :                 }
     101            0 : 
     102            0 :                 c.Next()
     103              :         }
     104              : }
     105              : 
     106              : // gin.ResponseWriter をラップし、WriteHeader/Write の直前で ensure() を呼ぶ
     107              : type corsResponseWriter struct {
     108              :         gin.ResponseWriter
     109              :         ensure func()
     110              : }
     111              : 
     112            0 : func (w *corsResponseWriter) WriteHeader(code int) {
     113            0 :         w.ensure()
     114            0 :         w.ResponseWriter.WriteHeader(code)
     115            0 : }
     116            0 : func (w *corsResponseWriter) Write(b []byte) (int, error) {
     117            0 :         // header 未送出のまま Write が先に呼ばれる場合に備えて
     118            0 :         w.ensure()
     119            0 :         return w.ResponseWriter.Write(b)
     120            0 : }
     121              : 
     122            0 : func addVary(c *gin.Context, v string) {
     123            0 :         const key = "Vary"
     124            0 :         cur := c.Writer.Header().Get(key)
     125            0 :         if cur == "" {
     126            0 :                 c.Header(key, v)
     127            0 :                 return
     128            0 :         }
     129              :         // 既に含まれていれば重複させない
     130            0 :         for _, part := range strings.Split(cur, ",") {
     131            0 :                 if strings.EqualFold(strings.TrimSpace(part), v) {
     132            0 :                         return
     133            0 :                 }
     134              :         }
     135            0 :         c.Header(key, cur+", "+v)
     136              : }
     137              : 
     138            1 : func envCSV(key string, def []string) []string {
     139            1 :         raw := strings.TrimSpace(os.Getenv(key))
     140            2 :         if raw == "" {
     141            1 :                 return def
     142            1 :         }
     143            0 :         parts := strings.Split(raw, ",")
     144            0 :         out := make([]string, 0, len(parts))
     145            0 :         for _, p := range parts {
     146            0 :                 if s := strings.TrimSpace(p); s != "" {
     147            0 :                         out = append(out, s)
     148            0 :                 }
     149              :         }
     150            0 :         return out
     151              : }
     152              : 
     153            1 : func envBool(key string, def bool) bool {
     154            1 :         raw := strings.TrimSpace(strings.ToLower(os.Getenv(key)))
     155            2 :         if raw == "" {
     156            1 :                 return def
     157            1 :         }
     158            0 :         switch raw {
     159            0 :         case "1", "true", "t", "yes", "y":
     160            0 :                 return true
     161            0 :         case "0", "false", "f", "no", "n":
     162            0 :                 return false
     163            0 :         default:
     164            0 :                 return def
     165              :         }
     166              : }
     167              : 
     168            1 : func envDuration(key string, def time.Duration) time.Duration {
     169            1 :         raw := strings.TrimSpace(os.Getenv(key))
     170            2 :         if raw == "" {
     171            1 :                 return def
     172            1 :         }
     173              :         // time.ParseDuration 形式(例: "2h", "30m")
     174            0 :         if d, err := time.ParseDuration(raw); err == nil {
     175            0 :                 return d
     176            0 :         }
     177            0 :         return def
     178              : }
     179              : 
     180            1 : func contains(list []string, v string) bool {
     181            2 :         for _, s := range list {
     182            1 :                 if s == v {
     183            0 :                         return true
     184            0 :                 }
     185              :         }
     186            1 :         return false
     187              : }
     188              : 
     189            0 : func containsFold(list []string, v string) bool {
     190            0 :         for _, s := range list {
     191            0 :                 if strings.EqualFold(strings.TrimSpace(s), strings.TrimSpace(v)) {
     192            0 :                         return true
     193            0 :                 }
     194              :         }
     195            0 :         return false
     196              : }
     197              : 
     198              : //func formatSeconds(d time.Duration) string {
     199              : //      sec := int(d / time.Second)
     200              : //      return strconvItoa(sec)
     201              : //}
     202              : 
     203              : // ループ回避のため最小限
     204              : //func strconvItoa(i int) string {
     205              : //      // 依存を減らすために簡易実装でもOKだが、通常は strconv.Itoa を使う
     206              : //      // ここでは標準を使います
     207              : //      return strconv.Itoa(i)
     208              : //}
        

Generated by: LCOV version 2.3.1-1