diff --git a/py/import.go b/py/import.go index 5ca6598a..8adc8568 100644 --- a/py/import.go +++ b/py/import.go @@ -108,7 +108,9 @@ func ImportModuleLevelObject(ctx Context, name string, globals, locals StringDic } if fromFile, ok := globals["__file__"]; ok { - opts.CurDir = filepath.Dir(string(fromFile.(String))) + if fromFileStr, ok := fromFile.(String); ok { + opts.CurDir = filepath.Dir(string(fromFileStr)) + } } module, err := RunFile(ctx, srcPathname, opts, name) @@ -344,14 +346,42 @@ func BuiltinImport(ctx Context, self Object, args Tuple, kwargs StringDict, curr var globals Object = currentGlobal var locals Object = NewStringDict() var fromlist Object = Tuple{} + var fromlistTuple Tuple var level Object = Int(0) err := ParseTupleAndKeywords(args, kwargs, "U|OOOi:__import__", kwlist, &name, &globals, &locals, &fromlist, &level) if err != nil { return nil, err } - if fromlist == None { - fromlist = Tuple{} + levelObj, ok := level.(Int) + if !ok { + return nil, ExceptionNewf(TypeError, "__import__() argument 5 must be int, not %s", level.Type().Name) + } + levelInt, err := levelObj.GoInt() + if err != nil { + return nil, err + } + + globalsDict, ok := globals.(StringDict) + if !ok { + if levelInt > 0 { + return nil, ExceptionNewf(TypeError, "globals must be a dict") + } + globalsDict = StringDict{} } - return ImportModuleLevelObject(ctx, string(name.(String)), globals.(StringDict), locals.(StringDict), fromlist.(Tuple), int(level.(Int))) + + localsDict, ok := locals.(StringDict) + if !ok { + localsDict = StringDict{} + } + + fromlistTuple = Tuple{} + if fromlist != None { + fromlistTuple, err = SequenceTuple(fromlist) + if err != nil { + return nil, err + } + } + + return ImportModuleLevelObject(ctx, string(name.(String)), globalsDict, localsDict, fromlistTuple, levelInt) } diff --git a/stdlib/builtin/tests/builtin.py b/stdlib/builtin/tests/builtin.py index 77a831a2..36cb8ff2 100644 --- a/stdlib/builtin/tests/builtin.py +++ b/stdlib/builtin/tests/builtin.py @@ -501,6 +501,26 @@ class C: pass assert lib.libfn() == 42 assert lib.libvar == 43 assert lib.libclass().method() == 44 +lib = __import__("lib", {}, {}, [""]) +assert lib.libfn() == 42 +ok = False +try: + __import__("lib", {}, {}, 1) +except TypeError: + ok = True +assert ok, "TypeError not raised" +lib = __import__("lib", 1, {}, [""]) +assert lib.libfn() == 42 +ok = False +try: + __import__("lib", 1, {}, [""], 1) +except TypeError as e: + if e.args[0] != "globals must be a dict": + raise + ok = True +assert ok, "TypeError not raised" +lib = __import__("lib", {"__file__": 1}, {}, [""]) +assert lib.libfn() == 42 doc="input" import sys