diff --git a/app/router/condition_geoip_test.go b/app/router/condition_geoip_test.go index 63bd222e..07f40b83 100644 --- a/app/router/condition_geoip_test.go +++ b/app/router/condition_geoip_test.go @@ -1,6 +1,7 @@ package router_test import ( + "fmt" "os" "path/filepath" "testing" @@ -13,16 +14,25 @@ import ( "google.golang.org/protobuf/proto" ) -func init() { - wd, err := os.Getwd() - common.Must(err) +func getAssetPath(file string) (string, error) { + path := platform.GetAssetLocation(file) + _, err := os.Stat(path) + if os.IsNotExist(err) { + path := filepath.Join("..", "..", "resources", file) + _, err := os.Stat(path) + if os.IsNotExist(err) { + return "", fmt.Errorf("can't find %s in standard asset locations or {project_root}/resources", file) + } + if err != nil { + return "", fmt.Errorf("can't stat %s: %v", path, err) + } + return path, nil + } + if err != nil { + return "", fmt.Errorf("can't stat %s: %v", path, err) + } - if _, err := os.Stat(platform.GetAssetLocation("geoip.dat")); err != nil && os.IsNotExist(err) { - common.Must(filesystem.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(wd, "..", "..", "resources", "geoip.dat"))) - } - if _, err := os.Stat(platform.GetAssetLocation("geosite.dat")); err != nil && os.IsNotExist(err) { - common.Must(filesystem.CopyFile(platform.GetAssetLocation("geosite.dat"), filepath.Join(wd, "..", "..", "resources", "geosite.dat"))) - } + return path, nil } func TestGeoIPMatcherContainer(t *testing.T) { @@ -217,10 +227,15 @@ func TestGeoIPMatcher6US(t *testing.T) { } func loadGeoIP(country string) ([]*router.CIDR, error) { - geoipBytes, err := filesystem.ReadAsset("geoip.dat") + path, err := getAssetPath("geoip.dat") if err != nil { return nil, err } + geoipBytes, err := filesystem.ReadFile(path) + if err != nil { + return nil, err + } + var geoipList router.GeoIPList if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil { return nil, err diff --git a/app/router/condition_test.go b/app/router/condition_test.go index 75ee86dd..97d05db9 100644 --- a/app/router/condition_test.go +++ b/app/router/condition_test.go @@ -1,8 +1,6 @@ package router_test import ( - "os" - "path/filepath" "strconv" "testing" @@ -10,7 +8,6 @@ import ( "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/platform" "github.com/xtls/xray-core/common/platform/filesystem" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/protocol/http" @@ -20,18 +17,6 @@ import ( "google.golang.org/protobuf/proto" ) -func init() { - wd, err := os.Getwd() - common.Must(err) - - if _, err := os.Stat(platform.GetAssetLocation("geoip.dat")); err != nil && os.IsNotExist(err) { - common.Must(filesystem.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(wd, "..", "..", "release", "config", "geoip.dat"))) - } - if _, err := os.Stat(platform.GetAssetLocation("geosite.dat")); err != nil && os.IsNotExist(err) { - common.Must(filesystem.CopyFile(platform.GetAssetLocation("geosite.dat"), filepath.Join(wd, "..", "..", "release", "config", "geosite.dat"))) - } -} - func withBackground() routing.Context { return &routing_session.Context{} } @@ -316,10 +301,15 @@ func TestRoutingRule(t *testing.T) { } func loadGeoSite(country string) ([]*Domain, error) { - geositeBytes, err := filesystem.ReadAsset("geosite.dat") + path, err := getAssetPath("geosite.dat") if err != nil { return nil, err } + geositeBytes, err := filesystem.ReadFile(path) + if err != nil { + return nil, err + } + var geositeList GeoSiteList if err := proto.Unmarshal(geositeBytes, &geositeList); err != nil { return nil, err