Source code

Revision control

Copy as Markdown

Other Tools

# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import re
from copy import deepcopy
from collections import OrderedDict
import itertools
import ipdl.ast
import ipdl.builtin
from ipdl.cxx.ast import *
from ipdl.cxx.code import *
from ipdl.type import ActorType, UnionType, TypeVisitor, builtinHeaderIncludes
from ipdl.util import hash_str
# -----------------------------------------------------------------------------
# "Public" interface to lowering
##
class LowerToCxx:
    def lower(self, tu, segmentcapacitydict):
        """Lower the IPDL translation unit |tu| to C++.

        Returns |(headers, cpps)|, two lists of |File|: the unit's own
        header/source first, then (only when |tu| declares a protocol) the
        Parent and Child actor header/source pairs, in that order.
        """
        # annotate the AST with IPDL/C++ IR-type stuff used later
        tu.accept(_DecorateWithCxxStuff())

        # Any modifications to the filename scheme here need corresponding
        # modifications in the ipdl.py driver script.
        basename = tu.name
        mainheader = File(basename + ".h")
        maincpp = File(basename + ".cpp")
        _GenerateProtocolCode().lower(tu, mainheader, maincpp, segmentcapacitydict)

        headers = [mainheader]
        cpps = [maincpp]
        if not tu.protocol:
            return headers, cpps

        protoname = tu.protocol.name
        # Parent side first, then child side, each getting its own .h/.cpp.
        for generator, suffix in (
            (_GenerateProtocolParentCode, "Parent"),
            (_GenerateProtocolChildCode, "Child"),
        ):
            actorname = protoname + suffix
            sideheader = File(actorname + ".h")
            sidecpp = File(actorname + ".cpp")
            generator().lower(tu, actorname, sideheader, sidecpp)
            headers.append(sideheader)
            cpps.append(sidecpp)
        return headers, cpps
# -----------------------------------------------------------------------------
# Helper code
##
def hashfunc(value):
    """Return |hash_str(value)| reduced to an unsigned 32-bit integer.

    Python's ``%`` with a positive modulus always yields a result in
    ``[0, 2**32)``, even for negative operands, so the former
    "if h < 0: h += 2**32" fix-up could never run and has been removed.
    """
    return hash_str(value) % 2**32
# Literal expressions used as special actor IDs: the null-actor ID (0) and
# the freed-actor ID (1).
_NULL_ACTOR_ID = ExprLiteral.ZERO
_FREED_ACTOR_ID = ExprLiteral.ONE
# Banner comment marking C++ output as generated by ipdlc. It is wrapped in
# the cxx |Whitespace| node (presumably emitted verbatim, likely at the top
# of each generated file -- confirm at the use site).
_DISCLAIMER = Whitespace(
    """//
// Automatically generated by ipdlc.
// Edit at your own risk
//
"""
)
class _struct:
    """Empty attribute bag: instances accept arbitrary attribute assignment."""
def _namespacedHeaderName(name, namespaces):
    """Return |name| prefixed by the '/'-joined |.name| of each namespace.

    With no namespaces (or only empty names) |name| is returned unchanged.
    """
    prefix = "/".join(ns.name for ns in namespaces)
    return "/".join((prefix, name)) if prefix else name


def _ipdlhHeaderName(tu):
    """Return the namespaced header path of a header translation unit."""
    assert tu.filetype == "header"
    return _namespacedHeaderName(tu.name, tu.namespaces)


def _protocolHeaderName(p, side=""):
    """Return the namespaced header path for protocol |p|.

    A non-empty |side| is title-cased and appended to the protocol name
    (e.g. 'parent' -> 'PFooParent').
    """
    suffix = side.title() if side else side
    return _namespacedHeaderName(p.name + suffix, p.namespaces)
def _includeGuardMacroName(headerfile):
    """Return the include-guard macro name for |headerfile|.

    Every character of |headerfile.name| that cannot appear in a C
    identifier is replaced with '_'. For names made only of ASCII letters,
    digits, '_', '.' and '/' the result is identical to the previous
    '.'/'/'-only substitution. Names containing e.g. '-' or spaces now also
    yield a valid macro instead of a broken |#ifndef| line.
    """
    return re.sub(r"[^A-Za-z0-9_]", "_", headerfile.name)


def _includeGuardStart(headerfile):
    """Return the |#ifndef GUARD| and |#define GUARD| directives for |headerfile|."""
    guard = _includeGuardMacroName(headerfile)
    return [CppDirective("ifndef", guard), CppDirective("define", guard)]


def _includeGuardEnd(headerfile):
    """Return the |#endif| directive closing |headerfile|'s include guard."""
    guard = _includeGuardMacroName(headerfile)
    return [CppDirective("endif", "// ifndef " + guard)]
def _messageStartName(ptype):
    """Name of the |<Protocol>MsgStart| identifier for protocol type |ptype|."""
    suffix = "MsgStart"
    return ptype.name() + suffix


def _protocolId(ptype):
    """Expression naming |ptype|'s MsgStart identifier."""
    msgstart = _messageStartName(ptype)
    return ExprVar(msgstart)


def _protocolIdType():
    """C++ type of a protocol ID."""
    return Type.INT32


def _actorName(pname, side):
    """Return |pname| + |side|, e.g. 'PFoo' + 'parent' -> 'PFooParent'.

    |side| is kept as-is when it already starts with an uppercase letter,
    otherwise it is title-cased.
    """
    suffix = side if side[0].isupper() else side.title()
    return pname + suffix


def _actorIdType():
    """C++ type of an actor ID."""
    return Type.INT32


def _actorTypeTagType():
    """C++ type of an actor type tag."""
    return Type.INT32
def _actorId(actor=None):
    """Expression |actor->Id()|, or the unqualified |Id()| when |actor| is None."""
    if actor is None:
        return ExprCall(ExprVar("Id"))
    return ExprCall(ExprSelect(actor, "->", "Id"))


def _actorHId(actorhandle):
    """Expression |actorhandle.mId|."""
    return ExprSelect(actorhandle, ".", "mId")


def _backstagePass():
    """Expression |mozilla::ipc::PrivateIPDLInterface()|."""
    callee = ExprVar("mozilla::ipc::PrivateIPDLInterface")
    return ExprCall(callee)


def _deleteId():
    """Message-ID constant of the __delete__ message."""
    return ExprVar("Msg___delete____ID")


def _deleteReplyId():
    """Message-ID constant of the __delete__ reply."""
    return ExprVar("Reply___delete____ID")


def _lookupListener(idexpr):
    """Expression |Lookup(idexpr)|."""
    callee = ExprVar("Lookup")
    return ExprCall(callee, args=[idexpr])
def _wrapInNestedNamespaces(thing, names):
    """Nest |thing| inside C++ namespaces |names| ([outer, ..., inner]).

    Returns |thing| itself when |names| is empty, otherwise the outermost
    |Namespace| node.
    """
    if not names:
        return thing
    outermost = Namespace(names[0])
    current = outermost
    for name in names[1:]:
        child = Namespace(name)
        current.addstmt(child)
        current = child
    current.addstmt(thing)
    return outermost


def _makeForwardDeclForQClass(clsname, quals, cls=True, struct=False):
    """Forward declaration of |clsname| placed inside namespaces |quals|."""
    fd = ForwardDecl(clsname, cls=cls, struct=struct)
    return _wrapInNestedNamespaces(fd, quals)


def _makeForwardDeclForActor(ptype, side):
    """Forward declaration of |ptype|'s |side| actor class in its namespaces."""
    qname = ptype.qname
    return _makeForwardDeclForQClass(_actorName(qname.baseid, side), qname.quals)


def _makeForwardDecl(type):
    """Forward declaration of the class for |type| in its namespaces."""
    return _makeForwardDeclForQClass(type.name(), type.qname.quals)


def _putInNamespaces(cxxthing, namespaces):
    """|namespaces| is in order [ outer, ..., inner ]"""
    return _wrapInNestedNamespaces(cxxthing, [ns.name for ns in namespaces])
def _sendPrefix(msgtype):
    """Prefix of the name of the C++ method that sends |msgtype|.

    Always "Send"; |msgtype| is accepted but not inspected.
    """
    return "Send"


def _recvPrefix(msgtype):
    """Prefix of the name of the C++ method that handles |msgtype|.

    Always "Recv"; |msgtype| is accepted but not inspected.
    """
    return "Recv"
def _flatTypeName(ipdltype):
    """Return a 'flattened' IPDL type name that can be used as an
    identifier.

    E.g., |Foo[]| --> |ArrayOfFoo|, |Foo?| --> |MaybeFoo|.
    """
    # NB: this logic depends heavily on what IPDL types are allowed to
    # be constructed; e.g., Foo[][] is disallowed. needs to be kept in
    # sync with grammar.
    if not ipdltype.isIPDL():
        return ipdltype.name()
    if ipdltype.isArray():
        prefix = "ArrayOf"
    elif ipdltype.isMaybe():
        prefix = "Maybe"
    elif ipdltype.isNotNull() or ipdltype.isUniquePtr():
        # NotNull and UniquePtr types just assume the underlying variant name
        # to avoid unnecessary noise, as eg a NotNull<T> and T should never
        # exist in the same union.
        prefix = ""
    else:
        return ipdltype.name()
    return prefix + _flatTypeName(ipdltype.basetype)


def _hasVisibleActor(ipdltype):
    """Return true iff a C++ decl of |ipdltype| would have an Actor* type.

    For example: |Actor[]| would turn into |Array<ActorParent*>|, so this
    function would return true for |Actor[]|.
    """
    if not ipdltype.isIPDL():
        return False
    if ipdltype.isActor():
        return True
    # Look through wrapper types (arrays, maybes, ...) to their base type.
    return ipdltype.hasBaseType() and _hasVisibleActor(ipdltype.basetype)
def _abortIfFalse(cond, msg):
    """Statement |MOZ_RELEASE_ASSERT(cond, "msg");|."""
    assertion = ExprCall(ExprVar("MOZ_RELEASE_ASSERT"), [cond, ExprLiteral.String(msg)])
    return StmtExpr(assertion)


def _refptr(T):
    """C++ type |RefPtr<T>|."""
    return Type("RefPtr", T=T)


def _alreadyaddrefed(T):
    """C++ type |already_AddRefed<T>|."""
    return Type("already_AddRefed", T=T)


def _tuple(types, const=False, ref=False):
    """C++ type |std::tuple<types...>|."""
    return Type("std::tuple", T=types, const=const, ref=ref)


def _promise(resolvetype, rejecttype, tail, resolver=False):
    """C++ type |MozPromise<resolvetype, rejecttype, tail>|.

    With |resolver| set, the type carries an inner |Private| type
    (presumably spelled |MozPromise<...>::Private|).
    """
    if resolver:
        inner = Type("Private")
    else:
        inner = None
    return Type("MozPromise", T=[resolvetype, rejecttype, tail], inner=inner)


def _makePromise(returns, side, resolver=False):
    """MozPromise type resolving with the bare type(s) of |returns|.

    Several return values resolve as a std::tuple; a single one resolves as
    its own bare type. Rejection uses ResponseRejectReason.
    """
    resolved = [r.bareType(side) for r in returns]
    resolvetype = _tuple(resolved) if len(resolved) > 1 else resolved[0]
    # MozPromise is purposefully made to be exclusive only. Really, we mean it.
    return _promise(
        resolvetype, _ResponseRejectReason.Type(), ExprLiteral.TRUE, resolver=resolver
    )


def _resolveType(returns, side):
    """Type a resolver receives: the send-side in-type(s) of |returns|."""
    intypes = [r.inType(side, "send") for r in returns]
    return _tuple(intypes) if len(intypes) > 1 else intypes[0]


def _makeResolver(returns, side):
    """Function type taking the resolve value(s) of |returns|."""
    param = Decl(_resolveType(returns, side), "")
    return TypeFunction([param])
def _cxxTemplateInstance(name, basetype, const, ref, copyable):
    """C++ type |name<basetype>| with the given qualifiers and copyability."""
    return Type(name, T=basetype, const=const, ref=ref, hasimplicitcopyctor=copyable)


def _cxxArrayType(basetype, const=False, ref=False):
    """|nsTArray<basetype>|; never implicitly copyable."""
    return _cxxTemplateInstance("nsTArray", basetype, const, ref, False)


def _cxxSpanType(basetype, const=False, ref=False):
    """|mozilla::Span<basetype const>|.

    |basetype| is deep-copied before being made right-const, so the
    caller's object is not modified.
    """
    elemtype = deepcopy(basetype)
    elemtype.rightconst = True
    return _cxxTemplateInstance("mozilla::Span", elemtype, const, ref, True)


def _cxxMaybeType(basetype, const=False, ref=False):
    """|mozilla::Maybe<basetype>|, implicitly copyable iff |basetype| is."""
    copyable = basetype.hasimplicitcopyctor
    return _cxxTemplateInstance("mozilla::Maybe", basetype, const, ref, copyable)


def _cxxReadResultType(basetype, const=False, ref=False):
    """|IPC::ReadResult<basetype>|, implicitly copyable iff |basetype| is."""
    copyable = basetype.hasimplicitcopyctor
    return _cxxTemplateInstance("IPC::ReadResult", basetype, const, ref, copyable)


def _cxxNotNullType(basetype, const=False, ref=False):
    """|mozilla::NotNull<basetype>|, implicitly copyable iff |basetype| is."""
    copyable = basetype.hasimplicitcopyctor
    return _cxxTemplateInstance("mozilla::NotNull", basetype, const, ref, copyable)


def _cxxManagedContainerType(basetype, const=False, ref=False):
    """|ManagedContainer<basetype>|; never implicitly copyable."""
    return _cxxTemplateInstance("ManagedContainer", basetype, const, ref, False)


def _cxxLifecycleProxyType(ptr=False):
    """|mozilla::ipc::ActorLifecycleProxy|, optionally as a pointer."""
    return Type("mozilla::ipc::ActorLifecycleProxy", ptr=ptr)
def _cxxSide(side):
    """Expression naming the mozilla::ipc side enumerator for |side|.

    |side| must be 'child' or 'parent'; anything else fails an assertion.
    """
    for name, enumerator in (
        ("child", "mozilla::ipc::ChildSide"),
        ("parent", "mozilla::ipc::ParentSide"),
    ):
        if side == name:
            return ExprVar(enumerator)
    assert 0


def _otherSide(side):
    """Return the opposite side: 'child' <-> 'parent'.

    Anything else fails an assertion.
    """
    for name, opposite in (("child", "parent"), ("parent", "child")):
        if side == name:
            return opposite
    assert 0
# XXX we need to remove these and install proper error handling


def _errorMsgExpr(msg):
    """Wrap a Python str in a C++ string literal; pass expressions through."""
    return ExprLiteral.String(msg) if isinstance(msg, str) else msg


def _errorCallStmt(fnname, arg):
    """Statement |fnname(arg);|."""
    return StmtExpr(ExprCall(ExprVar(fnname), args=[arg]))


def _printErrorMessage(msg):
    """Statement |NS_ERROR(msg);|; |msg| may be a str or an expression."""
    return _errorCallStmt("NS_ERROR", _errorMsgExpr(msg))


def _protocolErrorBreakpoint(msg):
    """Statement |mozilla::ipc::ProtocolErrorBreakpoint(msg);|."""
    return _errorCallStmt("mozilla::ipc::ProtocolErrorBreakpoint", _errorMsgExpr(msg))


def _printWarningMessage(msg):
    """Statement |NS_WARNING(msg);|; |msg| may be a str or an expression."""
    return _errorCallStmt("NS_WARNING", _errorMsgExpr(msg))


def _fatalError(msg):
    """Statement |FatalError("msg");|; |msg| is always wrapped as a literal."""
    return _errorCallStmt("FatalError", ExprLiteral.String(msg))


def _logicError(msg):
    """Statement |mozilla::ipc::LogicError("msg");|."""
    return _errorCallStmt("mozilla::ipc::LogicError", ExprLiteral.String(msg))


def _sentinelReadError(classname):
    """Statement |mozilla::ipc::SentinelReadError("classname");|."""
    return _errorCallStmt(
        "mozilla::ipc::SentinelReadError", ExprLiteral.String(classname)
    )
# An ASCII C/C++ identifier: a letter or underscore, then letters, digits or
# underscores.
identifierRegExp = re.compile("[a-zA-Z_][a-zA-Z0-9_]*")


def _validCxxIdentifier(name):
    """Return True iff all of |name| has the form of a C++ identifier.

    Only the lexical form is checked; C++ keywords are not rejected.
    """
    return bool(identifierRegExp.fullmatch(name))
# Results that IPDL-generated code returns back to *Channel code.
# Users never see these
class _Result:
    """Expressions naming the message-handling result codes (Msg*)."""

    @staticmethod
    def Type():
        # C++ type of these result codes.
        return Type("Result")

    Processed = ExprVar("MsgProcessed")
    NotKnown = ExprVar("MsgNotKnown")
    NotAllowed = ExprVar("MsgNotAllowed")
    PayloadError = ExprVar("MsgPayloadError")
    ProcessingError = ExprVar("MsgProcessingError")
    RouteError = ExprVar("MsgRouteError")
    ValuError = ExprVar("MsgValueError")  # [sic] -- attribute name is misspelled; the C++ name is MsgValueError
# these |errfn*| are functions that generate code to be executed on an
# error, such as "bad actor ID". each is given a Python string
# containing a description of the error


# used in user-facing Send*() methods
def errfnSend(msg, errcode=ExprLiteral.FALSE):
    """Statements: raise a fatal error with |msg|, then |return errcode;|."""
    fatal = _fatalError(msg)
    return [fatal, StmtReturn(errcode)]


def errfnSendCtor(msg):
    """errfnSend variant for constructors: returns a null pointer."""
    return errfnSend(msg, errcode=ExprLiteral.NULL)


# TODO should this error handling be strengthened for dtors?
def errfnSendDtor(msg):
    """Statements: report |msg| via NS_ERROR, then |return false;|."""
    report = _printErrorMessage(msg)
    return [report, StmtReturn.FALSE]


# used in |OnMessage*()| handlers that hand in-messages off to Recv*()
# interface methods
def errfnRecv(msg, errcode=_Result.ValuError):
    """Statements: raise a fatal error, then return |errcode| (MsgValueError by default)."""
    fatal = _fatalError(msg)
    return [fatal, StmtReturn(errcode)]


def errfnSentinel(rvalue=ExprLiteral.FALSE):
    """Return an errfn that reports a sentinel read error and returns |rvalue|."""

    def _onSentinelError(msg):
        return [_sentinelReadError(msg), StmtReturn(rvalue)]

    return _onSentinelError


def errfnUnreachable(msg):
    """Statements: report a logic error (no return statement is added)."""
    return [_logicError(msg)]


def readResultError():
    """Expression |{}|, presumably the error value for ReadResult-returning code."""
    return ExprCode("{}")
class _DestroyReason:
    """Expressions naming the ActorDestroyReason values used in generated code."""

    @staticmethod
    def Type():
        # C++ type holding these values.
        return Type("ActorDestroyReason")

    Deletion = ExprVar("Deletion")
    AncestorDeletion = ExprVar("AncestorDeletion")
    NormalShutdown = ExprVar("NormalShutdown")
    AbnormalShutdown = ExprVar("AbnormalShutdown")
    FailedConstructor = ExprVar("FailedConstructor")
    ManagedEndpointDropped = ExprVar("ManagedEndpointDropped")
class _ResponseRejectReason:
    """Expressions naming the ResponseRejectReason values.

    Type() is also used as the reject type of the MozPromise built by
    _makePromise."""

    @staticmethod
    def Type():
        # C++ type holding these values.
        return Type("ResponseRejectReason")

    SendError = ExprVar("ResponseRejectReason::SendError")
    ChannelClosed = ExprVar("ResponseRejectReason::ChannelClosed")
    HandlerRejected = ExprVar("ResponseRejectReason::HandlerRejected")
    ActorDestroyed = ExprVar("ResponseRejectReason::ActorDestroyed")
# -----------------------------------------------------------------------------
# Intermediate representation (IR) nodes used during lowering
class _ConvertToCxxType(TypeVisitor):
    """TypeVisitor mapping an IPDL type to the bare C++ |Type| declaring it.

    |side| is 'parent', 'child' or None. With None, actor types become a
    |SideVariant| over both sides. |fq| selects fully-qualified names.
    """

    def __init__(self, side, fq):
        self.side = side
        self.fq = fq

    def typename(self, thing):
        return thing.fullname() if self.fq else thing.name()

    def _named(self, thing):
        """Plain C++ type spelled with |thing|'s (possibly qualified) name."""
        return Type(self.typename(thing))

    def visitImportedCxxType(self, t):
        named = self._named(t)
        # Refcounted imported types are held through RefPtr<T>.
        return _refptr(named) if t.isRefcounted() else named

    def visitBuiltinCType(self, b):
        return self._named(b)

    def visitActorType(self, a):
        if self.side is not None:
            return Type(_actorName(self.typename(a.protocol), self.side), ptr=True)
        # No side given: the actor may be either, so use a variant of both.
        variants = [_cxxBareType(a, s, self.fq) for s in ("parent", "child")]
        return Type("::mozilla::ipc::SideVariant", T=variants)

    def visitStructType(self, s):
        return self._named(s)

    def visitUnionType(self, u):
        return self._named(u)

    def visitArrayType(self, a):
        return _cxxArrayType(a.basetype.accept(self))

    def visitMaybeType(self, m):
        return _cxxMaybeType(m.basetype.accept(self))

    def visitShmemType(self, s):
        return self._named(s)

    def visitByteBufType(self, s):
        return self._named(s)

    def visitFDType(self, s):
        return self._named(s)

    def visitEndpointType(self, s):
        return self._named(s)

    def visitManagedEndpointType(self, s):
        return self._named(s)

    def visitUniquePtrType(self, s):
        return self._named(s)

    def visitNotNullType(self, n):
        return _cxxNotNullType(n.basetype.accept(self))

    # The remaining IPDL types have no C++ value type.
    def visitProtocolType(self, p):
        assert 0

    def visitMessageType(self, m):
        assert 0

    def visitVoidType(self, v):
        assert 0
def _cxxBareType(ipdltype, side, fq=False):
    """Return the plain C++ type of |ipdltype| as seen from |side|.

    Names are fully qualified when |fq| is set.
    """
    converter = _ConvertToCxxType(side, fq)
    return ipdltype.accept(converter)


def _cxxRefType(ipdltype, side):
    """Return the bare C++ type of |ipdltype| marked as a reference (T&)."""
    reftype = _cxxBareType(ipdltype, side)
    reftype.ref = True
    return reftype
def _cxxConstRefType(ipdltype, side):
    """Return the C++ type used to refer to a |ipdltype| value through a
    const 'reference' (not necessarily a C++ reference).

    Cases are checked in order:
      - actors: the bare actor type (a raw pointer when |side| is given,
        a SideVariant when it is None);
      - Shmem: a mutable |Shmem&|;
      - NotNull: |NotNull<...>| around the inner const-ref type when that is
        a pointer, otherwise the bare NotNull type by value;
      - other IPDL types with a base type (e.g. arrays, maybes): a reference
        that is const unless the inner type's form is a non-const reference;
      - C++ types that are move-only for send or for data: |const T&|;
      - refcounted C++ types: |T*| instead of |const RefPtr<T>&|;
      - everything else: |const T&|.
    """
    t = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor():
        return t
    if ipdltype.isIPDL() and ipdltype.isShmem():
        t.ref = True
        return t
    if ipdltype.isIPDL() and ipdltype.isNotNull():
        # If the inner type chooses to use a raw pointer, wrap that instead.
        inner = _cxxConstRefType(ipdltype.basetype, side)
        if inner.ptr:
            t = _cxxNotNullType(inner)
        return t
    if ipdltype.isIPDL() and ipdltype.hasBaseType():
        # Keep same constness as inner type.
        inner = _cxxConstRefType(ipdltype.basetype, side)
        # An inner form that is not a reference at all (e.g. an actor
        # pointer) still gives a const outer reference.
        t.const = inner.const or not inner.ref
        t.ref = True
        return t
    if ipdltype.isCxx() and (ipdltype.isSendMoveOnly() or ipdltype.isDataMoveOnly()):
        # Checked before the refcounted case, so such types keep const T&.
        t.const = True
        t.ref = True
        return t
    if ipdltype.isCxx() and ipdltype.isRefcounted():
        # Use T* instead of const RefPtr<T>&
        t = t.T
        t.ptr = True
        return t
    t.const = True
    t.ref = True
    return t
def _cxxTypeNeedsMoveForSend(ipdltype, context="root", visited=None):
    """Returns `True` if serializing ipdltype requires a mutable reference, e.g.
    because the underlying resource represented by the value is being
    transferred to another process. This is occasionally distinct from whether
    the C++ type exposes a copy constructor, such as for types which are not
    cheaply copiable, but are not mutated when serialized.

    |context| records how this type was reached: "root" for the top-level
    call, "wrapper" when recursing into the base type of a type with
    |hasBaseType()|, and "compound" when recursing into a struct or union
    component. |visited| accumulates every type examined so far and is shared
    by the whole traversal. Struct/union components already in it are
    skipped, which stops infinite recursion through recursive types."""
    if visited is None:
        visited = set()
    visited.add(ipdltype)
    if ipdltype.isCxx():
        # Imported C++ types report this via isSendMoveOnly().
        return ipdltype.isSendMoveOnly()
    if ipdltype.isIPDL():
        if ipdltype.hasBaseType():
            # A wrapper needs a move exactly when its base type does.
            return _cxxTypeNeedsMoveForSend(ipdltype.basetype, "wrapper", visited)
        if ipdltype.isStruct() or ipdltype.isUnion():
            # A compound needs a move if any not-yet-visited component does.
            # NB: |visited| grows while this generator is being consumed.
            return any(
                _cxxTypeNeedsMoveForSend(t, "compound", visited)
                for t in ipdltype.itercomponents()
                if t not in visited
            )
        # For historical reasons, shmem is `const_cast` to a mutable reference
        # when being stored in a struct or union (see
        # `_StructField.constRefExpr` and `_UnionMember.getConstValue`), meaning
        # that they do not cause the containing struct to require move for
        # sending.
        if ipdltype.isShmem():
            return context != "compound"
        # These IPDL builtins always require a mutable reference to send.
        return (
            ipdltype.isByteBuf()
            or ipdltype.isEndpoint()
            or ipdltype.isManagedEndpoint()
        )
    return False
def _cxxTypeNeedsMoveForData(ipdltype, context="root", visited=None):
    """Returns `True` if the bare C++ type corresponding to ipdltype does not
    satisfy std::is_copy_constructible_v<T>. All C++ types supported by IPDL
    must support std::is_move_constructible_v<T>, so non-movable types must be
    passed behind a `UniquePtr`.

    |context| records how this type was reached: "root" for the top-level
    call, "wrapper" when recursing into the base type of a type with
    |hasBaseType()|, and "compound" when recursing into a struct or union
    component. |visited| accumulates every type examined so far and is shared
    by the whole traversal. Struct/union components already in it are
    skipped, which stops infinite recursion through recursive types."""
    if visited is None:
        visited = set()
    visited.add(ipdltype)
    if ipdltype.isCxx():
        # Imported C++ types report this via isDataMoveOnly().
        return ipdltype.isDataMoveOnly()
    if ipdltype.isIPDL():
        if ipdltype.isUniquePtr():
            return True
        # When nested within a maybe or array, arrays are no longer copyable.
        if context == "wrapper" and ipdltype.isArray():
            return True
        if ipdltype.hasBaseType():
            # Otherwise a wrapper is move-only exactly when its base type is.
            return _cxxTypeNeedsMoveForData(ipdltype.basetype, "wrapper", visited)
        if ipdltype.isStruct() or ipdltype.isUnion():
            # A compound is move-only if any not-yet-visited component is.
            # NB: |visited| grows while this generator is being consumed.
            return any(
                _cxxTypeNeedsMoveForData(t, "compound", visited)
                for t in ipdltype.itercomponents()
                if t not in visited
            )
        # These IPDL builtins are always move-only.
        return (
            ipdltype.isByteBuf()
            or ipdltype.isEndpoint()
            or ipdltype.isManagedEndpoint()
        )
    return False
def _cxxTypeCanMove(ipdltype):
    """True unless |ipdltype| is an IPDL actor type (actors are not moved)."""
    return not ipdltype.isIPDL() or not ipdltype.isActor()


def _cxxForceMoveRefType(ipdltype, side):
    """Bare C++ type of |ipdltype| as an rvalue reference (T&&)."""
    assert _cxxTypeCanMove(ipdltype)
    moved = _cxxBareType(ipdltype, side)
    moved.rvalref = True
    return moved


def _cxxPtrToType(ipdltype, side):
    """Pointer-to-|ipdltype| C++ type.

    With a side, actor types are already pointers, so they become T**.
    """
    ptrtype = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor() and side is not None:
        ptrtype.ptr = False
        ptrtype.ptrptr = True
    else:
        ptrtype.ptr = True
    return ptrtype


def _cxxConstPtrToType(ipdltype, side):
    """Pointer-to-const-|ipdltype| C++ type.

    With a side, actor types become a pointer to a const pointer.
    """
    ptrtype = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor() and side is not None:
        ptrtype.ptr = False
        ptrtype.ptrconstptr = True
    else:
        ptrtype.const = True
        ptrtype.ptr = True
    return ptrtype
def _cxxInType(ipdltype, side, direction):
    """Return the C++ type of an in-parameter of type |ipdltype|.

    |direction| is only compared against "send" here; callers in this file
    pass "send" or "recv". Cases are checked in order:
      - actors: the bare actor type;
      - NotNull: |NotNull<...>| around the inner in-type when that is a
        pointer, otherwise the bare NotNull type by value;
      - types for which _cxxTypeNeedsMoveForSend() is true: |T&&|;
      - refcounted C++ types: |T*|;
      - nsCString / nsString: |const nsACString&| / |const nsAString&|;
      - IPDL arrays when |direction| == "send": |Span<T const>|;
      - everything else: |const T&|.
    """
    t = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor():
        return t
    if ipdltype.isIPDL() and ipdltype.isNotNull():
        # If the inner type chooses to use a raw pointer, wrap that instead.
        inner = _cxxInType(ipdltype.basetype, side, direction)
        if inner.ptr:
            t = _cxxNotNullType(inner)
        return t
    if _cxxTypeNeedsMoveForSend(ipdltype):
        t.rvalref = True
        return t
    if ipdltype.isCxx():
        if ipdltype.isRefcounted():
            # Use T* instead of const RefPtr<T>&
            t = t.T
            t.ptr = True
            return t
        # Accept the abstract string base classes; these fall through to
        # the const-reference case below.
        if ipdltype.name() == "nsCString":
            t = Type("nsACString")
        if ipdltype.name() == "nsString":
            t = Type("nsAString")
    # Use Span<T const> rather than nsTArray<T> for array types which aren't
    # `_cxxTypeNeedsMoveForSend`. This is only done for the "send" side, and not
    # for recv signatures.
    if direction == "send" and ipdltype.isIPDL() and ipdltype.isArray():
        inner = _cxxBareType(ipdltype.basetype, side)
        return _cxxSpanType(inner)
    t.const = True
    t.ref = True
    return t
def _allocMethod(ptype, side):
    """Name of the Alloc method for |ptype|'s |side| actor, e.g. AllocPFooParent."""
    return "".join(("Alloc", ptype.name(), side.title()))


def _deallocMethod(ptype, side):
    """Name of the Dealloc method for |ptype|'s |side| actor, e.g. DeallocPFooChild."""
    return "".join(("Dealloc", ptype.name(), side.title()))
##
# A _HybridDecl straddles IPDL and C++ decls. It knows which C++
# types correspond to which IPDL types, and it also knows how to
# serialize and deserialize "special" IPDL C++ types.
##
class _HybridDecl:
    """A hybrid decl stores both an IPDL type and all the C++ type
    info needed by later passes, along with a basic name for the decl."""

    def __init__(self, ipdltype, name, attributes=None):
        """|attributes| is the decl's attribute mapping (queried with |in|,
        e.g. for "NoTaint"). When omitted, each instance gets its own empty
        dict."""
        self.ipdltype = ipdltype
        self.name = name
        # A `{}` default would be one dict shared by every decl constructed
        # without attributes (e.g. all struct fields and union members), so
        # a mutation through one decl would leak into all the others.
        self.attributes = attributes if attributes is not None else {}

    def var(self):
        """Expression naming this decl's variable."""
        return ExprVar(self.name)

    def bareType(self, side, fq=False):
        """Return this decl's unqualified C++ type."""
        return _cxxBareType(self.ipdltype, side, fq=fq)

    def refType(self, side):
        """Return this decl's C++ type as a 'reference' type, which is not
        necessarily a C++ reference."""
        return _cxxRefType(self.ipdltype, side)

    def constRefType(self, side):
        """Return this decl's C++ type as a const, 'reference' type."""
        return _cxxConstRefType(self.ipdltype, side)

    def ptrToType(self, side):
        """Return a pointer-to-this-decl's-type C++ type."""
        return _cxxPtrToType(self.ipdltype, side)

    def constPtrToType(self, side):
        """Return a pointer-to-const-this-decl's-type C++ type."""
        return _cxxConstPtrToType(self.ipdltype, side)

    def inType(self, side, direction):
        """Return this decl's C++ Type with sending inparam semantics."""
        return _cxxInType(self.ipdltype, side, direction)

    def outType(self, side):
        """Return this decl's C++ Type with outparam semantics."""
        t = self.bareType(side)
        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
            # Actor types are pointers already; an outparam is a T**.
            t.ptr = False
            t.ptrptr = True
            return t
        t.ptr = True
        return t

    def forceMoveType(self, side):
        """Return this decl's C++ Type with forced move semantics."""
        assert _cxxTypeCanMove(self.ipdltype)
        return _cxxForceMoveRefType(self.ipdltype, side)
# --------------------------------------------------
class HasFQName:
    """Mixin for decls whose |self.decl.type| can report a fully-qualified name."""

    def fqClassName(self):
        """Fully-qualified C++ class name of this decl's type."""
        decltype = self.decl.type
        return decltype.fullname()
class _CompoundTypeComponent(_HybridDecl):
    """_HybridDecl for a struct field or union member.

    Overrides the type accessors so that |side| is optional (default None).
    """

    def bareType(self, side=None, fq=False):
        return super().bareType(side, fq=fq)

    def refType(self, side=None):
        return super().refType(side)

    def constRefType(self, side=None):
        return super().constRefType(side)

    def ptrToType(self, side=None):
        return super().ptrToType(side)

    def constPtrToType(self, side=None):
        return super().constPtrToType(side)

    def forceMoveType(self, side=None):
        return super().forceMoveType(side)
class StructDecl(ipdl.ast.StructDecl, HasFQName):
    """ipdl.ast.StructDecl re-classed in place by |upgrade|, adding field
    iteration orders."""

    def fields_ipdl_order(self):
        """Yield fields in their IPDL declaration order."""
        yield from self.fields

    def fields_member_order(self):
        """Yield fields in the order given by |packed_field_order|."""
        assert len(self.packed_field_order) == len(self.fields)
        yield from (self.fields[i] for i in self.packed_field_order)

    @staticmethod
    def upgrade(structDecl):
        """Re-class |structDecl| (an ipdl.ast.StructDecl) in place."""
        assert isinstance(structDecl, ipdl.ast.StructDecl)
        structDecl.__class__ = StructDecl
class _StructField(_CompoundTypeComponent):
    """A field of an IPDL struct."""

    def __init__(self, ipdltype, name, sd):
        # |sd| is not used by this constructor.
        self.basename = name
        super().__init__(ipdltype, name)

    def getMethod(self, thisexpr=None, sel="."):
        """Expression naming the field's accessor.

        The accessor is selected from |thisexpr| with |sel| when
        |thisexpr| is given.
        """
        accessor = self.var()
        if thisexpr is None:
            return accessor
        return ExprSelect(thisexpr, sel, accessor.name)

    def refExpr(self, thisexpr=None):
        """Expression for the field's storage member, optionally on |thisexpr|."""
        member = self.memberVar()
        if thisexpr is None:
            return member
        return ExprSelect(thisexpr, ".", member.name)

    def constRefExpr(self, thisexpr=None):
        """Like refExpr(), but Shmem fields are cast to a mutable |Shmem&|."""
        refexpr = self.refExpr(thisexpr)
        if self.ipdltype.name() != "Shmem":
            return refexpr
        # sigh, gross hack: const_cast Shmem fields to a mutable reference.
        return ExprCast(refexpr, Type("Shmem", ref=True), const=True)

    def argVar(self):
        """Argument variable for this field: |_name|."""
        return ExprVar("_" + self.name)

    def memberVar(self):
        """Storage member variable for this field: |name_|."""
        return ExprVar(self.name + "_")
class UnionDecl(ipdl.ast.UnionDecl, HasFQName):
    """ipdl.ast.UnionDecl re-classed in place by |upgrade|."""

    def callType(self, var=None):
        """Expression calling the union's |type()| accessor.

        The accessor is called on |var| (with '.') when |var| is given.
        """
        if var is None:
            return ExprCall(ExprVar("type"))
        return ExprCall(ExprSelect(var, ".", "type"))

    @staticmethod
    def upgrade(unionDecl):
        """Re-class |unionDecl| (an ipdl.ast.UnionDecl) in place."""
        assert isinstance(unionDecl, ipdl.ast.UnionDecl)
        unionDecl.__class__ = UnionDecl
class _UnionMember(_CompoundTypeComponent):
    """Not in the AFL sense, but rather a member (e.g. |int;|) of an
    IPDL union type.

    The member is named "V" + its flattened type name. Its value lives in the
    union's C union |mValue|: inline in aligned storage, or behind a heap
    pointer when the member type is mutually recursive with the union."""

    def __init__(self, ipdltype, ud):
        # |ud| is the owning union decl; its type is queried for recursion.
        flatname = _flatTypeName(ipdltype)
        # The flattened name is spliced into C++ identifiers (V*, T*, get_*,
        # ptr_*, ...), so it must be identifier-shaped.
        assert _validCxxIdentifier(flatname)
        _CompoundTypeComponent.__init__(self, ipdltype, "V" + flatname)
        self.flattypename = flatname
        # To create a finite object with a mutually recursive type, a union must
        # be present somewhere in the recursive loop. Because of that we only
        # need to care about introducing indirections inside unions.
        self.recursive = ud.decl.type.mutuallyRecursiveWith(ipdltype)

    def enum(self):
        # Name of this member's enumerator: "T" + flattened type name
        # (presumably the union's type tag).
        return "T" + self.flattypename

    def enumvar(self):
        # Expression naming enum().
        return ExprVar(self.enum())

    def internalType(self):
        # Type actually held: a pointer for recursive members, else the
        # bare type.
        if self.recursive:
            return self.ptrToType()
        else:
            return self.bareType()

    def unionType(self):
        """Type used for storage in generated C union decl."""
        if self.recursive:
            return self.ptrToType()
        else:
            return Type("mozilla::AlignedStorage2", T=self.internalType())

    def unionValue(self):
        # NB: knows that Union's storage C union is named |mValue|
        return ExprSelect(ExprVar("mValue"), ".", self.name)

    def typedef(self):
        # Typedef name for the member type; callDtor() uses it to spell the
        # explicit destructor call.
        return self.flattypename + "__tdef"

    def callGetConstPtr(self):
        """Return a call to this member's |constptr_<flattypename>()| accessor."""
        return ExprCall(ExprVar(self.getConstPtrName()))

    def callGetPtr(self):
        """Return a call to this member's |ptr_<flattypename>()| accessor."""
        return ExprCall(ExprVar(self.getPtrName()))

    def callCtor(self, expr=None):
        # Expression constructing this member's value from |expr| (or
        # default-constructing it when |expr| is None).
        assert not isinstance(expr, list)
        if expr is None:
            args = None
        elif (
            self.ipdltype.isIPDL()
            and self.ipdltype.isArray()
            and not isinstance(expr, ExprMove)
        ):
            # nsTArray has no implicit copy constructor, so non-moved
            # array arguments are copied via Clone().
            args = [ExprCall(ExprSelect(expr, ".", "Clone"), args=[])]
        else:
            args = [expr]
        if self.recursive:
            # Recursive members: heap-allocate and store the pointer.
            return ExprAssn(self.callGetPtr(), ExprNew(self.bareType(), args=args))
        else:
            # Inline members: construct at the storage address (presumably
            # emitted as placement new with mozilla::KnownNotNull).
            return ExprNew(
                self.bareType(),
                args=args,
                newargs=[ExprVar("mozilla::KnownNotNull"), self.callGetPtr()],
            )

    def callDtor(self):
        # Recursive members are deleted; inline members get an explicit
        # destructor call through the typedef name.
        if self.recursive:
            return ExprDelete(self.callGetPtr())
        else:
            return ExprCall(ExprSelect(self.callGetPtr(), "->", "~" + self.typedef()))

    def getTypeName(self):
        return "get_" + self.flattypename

    def getConstTypeName(self):
        return "get_" + self.flattypename

    def getOtherTypeName(self):
        # NOTE(review): |self.otherflattypename| is not assigned anywhere in
        # this class, so this raises AttributeError unless it is set
        # elsewhere -- confirm whether this is dead code.
        return "get_" + self.otherflattypename

    def getPtrName(self):
        return "ptr_" + self.flattypename

    def getConstPtrName(self):
        return "constptr_" + self.flattypename

    def ptrToSelfExpr(self):
        """|*ptrToSelfExpr()| has type |self.bareType()|"""
        v = self.unionValue()
        if self.recursive:
            return v
        else:
            return ExprCall(ExprSelect(v, ".", "addr"))

    def constptrToSelfExpr(self):
        """Return a pointer expression to this member's value in |mValue|.

        This is the stored pointer itself when recursive, otherwise |.addr()|
        on the aligned storage; i.e. the same expression as ptrToSelfExpr()."""
        v = self.unionValue()
        if self.recursive:
            return v
        return ExprCall(ExprSelect(v, ".", "addr"))

    def ptrToInternalType(self):
        # Pointer to the member type; for recursive members a reference to
        # that pointer (T*&), since the pointer is what is stored.
        t = self.ptrToType()
        if self.recursive:
            t.ref = True
        return t

    def defaultValue(self, fq=False):
        # Use the default constructor for any class that does not have an
        # implicit copy constructor.
        if not self.bareType().hasimplicitcopyctor:
            return None
        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
            return ExprLiteral.NULL
        # XXX sneaky here, maybe need ExprCtor()?
        return ExprCall(self.bareType(fq=fq))

    def getConstValue(self):
        # Dereferenced const pointer to the value.
        v = ExprDeref(self.callGetConstPtr())
        # sigh: Shmem members are cast to a mutable Shmem& (see the
        # const_cast note in _cxxTypeNeedsMoveForSend).
        if "Shmem" == self.ipdltype.name():
            v = ExprCast(v, Type("Shmem", ref=True), const=True)
        return v
# --------------------------------------------------
class MessageDecl(ipdl.ast.MessageDecl):
    """ipdl.ast.MessageDecl extended with the method names, message IDs and
    C++ parameter/argument lists used by generated code. Instances are made by
    re-classing an ipdl.ast.MessageDecl in place (see upgrade())."""

    def baseName(self):
        # The message's IPDL name; the other names derive from it.
        return self.name

    def recvMethod(self):
        """Name of the C++ handler method, e.g. RecvFoo (RecvFooConstructor
        for constructors)."""
        name = _recvPrefix(self.decl.type) + self.baseName()
        if self.decl.type.isCtor():
            name += "Constructor"
        return name

    def sendMethod(self):
        """Name of the C++ send method, e.g. SendFoo (SendFooConstructor for
        constructors)."""
        name = _sendPrefix(self.decl.type) + self.baseName()
        if self.decl.type.isCtor():
            name += "Constructor"
        return name

    def hasReply(self):
        # Constructors and destructors count as having a reply even when the
        # message type itself does not report one.
        return (
            self.decl.type.hasReply()
            or self.decl.type.isCtor()
            or self.decl.type.isDtor()
        )

    def hasAsyncReturns(self):
        # NB: for async messages this evaluates to |self.returns| itself
        # (truthy when non-empty), not to a bool.
        return self.decl.type.isAsync() and self.returns

    def msgCtorFunc(self):
        # "Msg_<progname>": name of the message constructor function.
        return "Msg_%s" % (self.decl.progname)

    def prettyMsgName(self, pfx=""):
        return pfx + self.msgCtorFunc()

    def pqMsgCtorFunc(self):
        # msgCtorFunc() qualified with this message's namespace.
        return "%s::%s" % (self.namespace, self.msgCtorFunc())

    def msgId(self):
        # "Msg_<progname>__ID": the message ID constant.
        return self.msgCtorFunc() + "__ID"

    def pqMsgId(self):
        return "%s::%s" % (self.namespace, self.msgId())

    def replyCtorFunc(self):
        # "Reply_<progname>": name of the reply constructor function.
        return "Reply_%s" % (self.decl.progname)

    def pqReplyCtorFunc(self):
        return "%s::%s" % (self.namespace, self.replyCtorFunc())

    def replyId(self):
        # "Reply_<progname>__ID": the reply ID constant.
        return self.replyCtorFunc() + "__ID"

    def pqReplyId(self):
        return "%s::%s" % (self.namespace, self.replyId())

    def prettyReplyName(self, pfx=""):
        return pfx + self.replyCtorFunc()

    def promiseName(self):
        # e.g. FooPromise, or FooConstructorPromise for constructors.
        name = self.baseName()
        if self.decl.type.isCtor():
            name += "Constructor"
        name += "Promise"
        return name

    def resolverName(self):
        # e.g. FooResolver.
        return self.baseName() + "Resolver"

    def actorDecl(self):
        # The first parameter; upgrade() inserts the implicit "actor"
        # parameter there for messages that have one.
        return self.params[0]

    def makeCxxParams(
        self, paramsems="in", returnsems="out", side=None, implicit=True, direction=None
    ):
        """Return a list of C++ decls per the spec'd configuration.
        |params| and |returns| is the C++ semantics of those: 'in', 'out', or None.

        |paramsems| ('in', 'move' or 'out') applies to every parameter. None
        omits the parameters. |returnsems| is 'in', 'out' or 'move' for one
        decl per return value. With return values present, 'promise' emits
        nothing for them, 'callback' emits aResolve/aReject callback decls,
        and 'resolver' emits a single aResolve decl. None omits the returns.
        For a tainted message with |direction| 'recv', decls without a
        "NoTaint" attribute become Tainted<T> by value. With |implicit| False,
        the first decl is dropped when the message has an implicit actor
        parameter."""

        def makeDecl(d, sems):
            if (
                self.decl.type.tainted
                and "NoTaint" not in d.attributes
                and direction == "recv"
            ):
                # Tainted types are passed by-value, allowing the receiver to move them if desired.
                assert sems != "out"
                return Decl(Type("Tainted", T=d.bareType(side)), d.name)
            if sems == "in":
                t = d.inType(side, direction)
                # If this is the `recv` side, and we're not using "move"
                # semantics, that means we're an alloc method, and cannot accept
                # values by rvalue reference. Downgrade to an lvalue reference.
                if direction == "recv" and t.rvalref:
                    t.rvalref = False
                    t.ref = True
                return Decl(t, d.name)
            elif sems == "move":
                assert direction == "recv"
                # For legacy reasons, use an rvalue reference when generating
                # parameters for recv methods which accept arrays.
                if d.ipdltype.isIPDL() and d.ipdltype.isArray():
                    t = d.bareType(side)
                    t.rvalref = True
                    return Decl(t, d.name)
                return Decl(d.inType(side, direction), d.name)
            elif sems == "out":
                return Decl(d.outType(side), d.name)
            else:
                assert 0

        def makeResolverDecl(returns):
            # |returns| is unused; the resolver type is named per message.
            return Decl(Type(self.resolverName(), rvalref=True), "aResolve")

        def makeCallbackResolveDecl(returns):
            # The callback receives a std::tuple when there are several
            # return values, else the single bare type.
            if len(returns) > 1:
                resolvetype = _tuple([d.bareType(side) for d in returns])
            else:
                resolvetype = returns[0].bareType(side)
            return Decl(
                Type("mozilla::ipc::ResolveCallback", T=resolvetype, rvalref=True),
                "aResolve",
            )

        def makeCallbackRejectDecl(returns):
            return Decl(Type("mozilla::ipc::RejectCallback", rvalref=True), "aReject")

        cxxparams = []
        if paramsems is not None:
            cxxparams.extend([makeDecl(d, paramsems) for d in self.params])
        if returnsems == "promise" and self.returns:
            pass
        elif returnsems == "callback" and self.returns:
            cxxparams.extend(
                [
                    makeCallbackResolveDecl(self.returns),
                    makeCallbackRejectDecl(self.returns),
                ]
            )
        elif returnsems == "resolver" and self.returns:
            cxxparams.extend([makeResolverDecl(self.returns)])
        elif returnsems is not None:
            cxxparams.extend([makeDecl(r, returnsems) for r in self.returns])
        if not implicit and self.decl.type.hasImplicitActorParam():
            # Drop the implicit actor parameter (always first).
            cxxparams = cxxparams[1:]
        return cxxparams

    def makeCxxArgs(
        self, paramsems="in", retsems="out", retcallsems="out", implicit=True
    ):
        """Return the C++ argument expressions for a call passing this
        message's params and returns.

        |paramsems| is 'move' (wrap each param in ExprMove, except refcounted
        and NotNull ones) or 'in' (pass as-is). For return values, |retsems|
        says how the caller holds them and |retcallsems| how the callee
        takes them ('in' or 'out'). The pair picks the variable itself, its
        address or its dereference. With |retsems| 'resolver', a moved
        |resolver| variable is appended instead. With |implicit| False, the
        first argument (the implicit actor) is dropped."""
        assert not retcallsems or retsems  # retcallsems => returnsems
        cxxargs = []
        if paramsems == "move":
            # We don't std::move() RefPtr<T> types because current Recv*()
            # implementors take these parameters as T*, and
            # std::move(RefPtr<T>) doesn't coerce to T*.
            # We also don't move NotNull, as it has no move constructor.
            cxxargs.extend(
                [
                    p.var()
                    if p.ipdltype.isRefcounted()
                    or (p.ipdltype.isIPDL() and p.ipdltype.isNotNull())
                    else ExprMove(p.var())
                    for p in self.params
                ]
            )
        elif paramsems == "in":
            cxxargs.extend([p.var() for p in self.params])
        else:
            assert False
        for ret in self.returns:
            if retsems == "in":
                if retcallsems == "in":
                    cxxargs.append(ret.var())
                elif retcallsems == "out":
                    cxxargs.append(ExprAddrOf(ret.var()))
                else:
                    assert 0
            elif retsems == "out":
                if retcallsems == "in":
                    cxxargs.append(ExprDeref(ret.var()))
                elif retcallsems == "out":
                    cxxargs.append(ret.var())
                else:
                    assert 0
            elif retsems == "resolver":
                pass
        if retsems == "resolver":
            cxxargs.append(ExprMove(ExprVar("resolver")))
        if not implicit:
            assert self.decl.type.hasImplicitActorParam()
            cxxargs = cxxargs[1:]
        return cxxargs

    @staticmethod
    def upgrade(messageDecl):
        # Re-class |messageDecl| in place. Messages with an implicit actor
        # parameter first get an "actor" param (typed as the constructed
        # actor) inserted at position 0.
        assert isinstance(messageDecl, ipdl.ast.MessageDecl)
        if messageDecl.decl.type.hasImplicitActorParam():
            messageDecl.params.insert(
                0,
                _HybridDecl(
                    ipdl.type.ActorType(messageDecl.decl.type.constructedType()),
                    "actor",
                ),
            )
        messageDecl.__class__ = MessageDecl
# --------------------------------------------------
def _usesShmem(p):
    """True iff a message of protocol |p| has a Shmem-containing parameter.

    Both in- and out-parameters are checked, as judged by
    ipdl.type.hasshmem.
    """
    return any(
        ipdl.type.hasshmem(param.type)
        for md in p.messageDecls
        for params in (md.inParams, md.outParams)
        for param in params
    )


def _subtreeUsesShmem(p):
    """True iff |p| or any protocol it (transitively) manages uses Shmem.

    A protocol listed as managing itself is not re-entered.
    """
    if _usesShmem(p):
        return True
    ptype = p.decl.type
    return any(
        _subtreeUsesShmem(managed._ast)
        for managed in ptype.manages
        if managed is not ptype
    )
class Protocol(ipdl.ast.Protocol):
    """ipdl.ast.Protocol re-classed in place by |upgrade|.

    Adds helpers building the C++ names and expressions used for this
    protocol.
    """

    def _ipdlmgrtype(self):
        """The protocol's single manager type (asserts there is exactly one)."""
        managers = self.decl.type.managers
        assert 1 == len(managers)
        return next(iter(managers), None)

    def managerActorType(self, side, ptr=False):
        """C++ type of the manager's |side| actor."""
        managername = self._ipdlmgrtype().name()
        return Type(_actorName(managername, side), ptr=ptr)

    def unregisterMethod(self, actorThis=None):
        """|Unregister|, or |actorThis->Unregister| when |actorThis| is given."""
        if actorThis is None:
            return ExprVar("Unregister")
        return ExprSelect(actorThis, "->", "Unregister")

    def removeManageeMethod(self):
        return ExprVar("RemoveManagee")

    def deallocManageeMethod(self):
        return ExprVar("DeallocManagee")

    def getChannelMethod(self):
        return ExprVar("GetIPCChannel")

    def callGetChannel(self, actorThis=None):
        """Call of GetIPCChannel(), on |actorThis| via '->' when given."""
        getter = self.getChannelMethod()
        if actorThis is None:
            return ExprCall(getter)
        return ExprCall(ExprSelect(actorThis, "->", getter.name))

    def processingErrorVar(self):
        assert self.decl.type.isToplevel()
        return ExprVar("ProcessingError")

    def shouldContinueFromTimeoutVar(self):
        assert self.decl.type.isToplevel()
        return ExprVar("ShouldContinueFromReplyTimeout")

    def routingId(self, actorThis=None):
        """MSG_ROUTING_CONTROL for toplevel protocols, else an Id() call."""
        if self.decl.type.isToplevel():
            return ExprVar("MSG_ROUTING_CONTROL")
        if actorThis is None:
            idfn = ExprVar("Id")
        else:
            idfn = ExprSelect(actorThis, "->", "Id")
        return ExprCall(idfn)

    def managerVar(self, thisexpr=None):
        """Call of Manager(), on |thisexpr| via '->' when given.

        Toplevel protocols must pass |thisexpr|.
        """
        assert thisexpr is not None or not self.decl.type.isToplevel()
        if thisexpr is None:
            callee = ExprVar("Manager")
        else:
            callee = ExprSelect(thisexpr, "->", "Manager")
        return ExprCall(callee, args=[])

    def managedCxxType(self, actortype, side):
        """Pointer type to the managed |side| actor of |actortype|."""
        assert self.decl.type.isManagerOf(actortype)
        actorname = _actorName(actortype.name(), side)
        return Type(actorname, ptr=True)

    def managedMethod(self, actortype, side):
        """|Managed<Actor>| accessor name for a managed actor type."""
        assert self.decl.type.isManagerOf(actortype)
        actorname = _actorName(actortype.name(), side)
        return ExprVar("Managed" + actorname)

    def managedVar(self, actortype, side):
        """|mManaged<Actor>| member name for a managed actor type."""
        assert self.decl.type.isManagerOf(actortype)
        actorname = _actorName(actortype.name(), side)
        return ExprVar("mManaged" + actorname)

    def managedVarType(self, actortype, side, const=False, ref=False):
        """ManagedContainer type holding the managed |side| actors."""
        assert self.decl.type.isManagerOf(actortype)
        elemtype = Type(_actorName(actortype.name(), side))
        return _cxxManagedContainerType(elemtype, const=const, ref=ref)

    def subtreeUsesShmem(self):
        return _subtreeUsesShmem(self)

    @staticmethod
    def upgrade(protocol):
        """Re-class |protocol| (an ipdl.ast.Protocol) in place."""
        assert isinstance(protocol, ipdl.ast.Protocol)
        protocol.__class__ = Protocol
class TranslationUnit(ipdl.ast.TranslationUnit):
    """Translation unit node whose class is swapped in place by |upgrade|."""

    @staticmethod
    def upgrade(tu):
        assert isinstance(tu, ipdl.ast.TranslationUnit)
        tu.__class__ = TranslationUnit
# -----------------------------------------------------------------------------
# Byte sizes of the C++ types treated as plain-old-data when laying out and
# (de)serializing IPDL structs: the fixed-width integers plus float/double.
pod_types = {
    "::%sint%d_t" % (sign, bits): bits // 8
    for bits, sign in itertools.product((8, 16, 32, 64), ("", "u"))
}
pod_types.update({"float": 4, "double": 8})
max_pod_size = max(pod_types.values())
# Any type we don't recognize is deliberately reported as "bigger" than every
# POD type, so that sorting by size puts non-POD types first.
pod_size_sentinel = 2 * max_pod_size


def pod_size(ipdltype):
    """Return the byte size of |ipdltype| if it is a known C++ POD type,
    otherwise pod_size_sentinel."""
    if ipdltype.isCxx():
        return pod_types.get(ipdltype.fullname(), pod_size_sentinel)
    return pod_size_sentinel
class _DecorateWithCxxStuff(ipdl.ast.Visitor):
    """Phase 1 of lowering: annotate the IPDL AST with data needed for C++
    code generation.

    Afterwards the tree is a hybrid "IR": mostly IPDL nodes upgraded in place
    to subclasses carrying C++ information, plus a few new IPDL/C++ nodes
    tailored for codegen."""

    def __init__(self):
        self.visitedTus = set()
        self.protocolName = None

    def visitTranslationUnit(self, tu):
        # Decorate each translation unit at most once per pass.
        if tu in self.visitedTus:
            return
        self.visitedTus.add(tu)
        ipdl.ast.Visitor.visitTranslationUnit(self, tu)
        if not isinstance(tu, TranslationUnit):
            TranslationUnit.upgrade(tu)

    def visitInclude(self, inc):
        # Only .ipdlh header units are recursed into from here.
        if inc.tu.filetype == "header":
            inc.tu.accept(self)

    def visitProtocol(self, pro):
        self.protocolName = pro.name
        Protocol.upgrade(pro)
        return ipdl.ast.Visitor.visitProtocol(self, pro)

    def visitStructDecl(self, sd):
        if isinstance(sd, StructDecl):
            # Already decorated.
            return
        fields = [_StructField(f.decl.type, f.name, sd) for f in sd.fields]
        # Storage permutation: field indices sorted by descending POD size so
        # the in-memory layout packs well. Non-POD fields report
        # pod_size_sentinel (larger than any POD size) and so come first;
        # the sort is stable, so equal sizes keep their IPDL order.
        order = sorted(
            range(len(fields)),
            key=lambda idx: pod_size(fields[idx].ipdltype),
            reverse=True,
        )
        sd.fields = fields
        sd.packed_field_order = order
        StructDecl.upgrade(sd)

    def visitUnionDecl(self, ud):
        members = [_UnionMember(ctype, ud) for ctype in ud.decl.type.components]
        ud.components = members
        UnionDecl.upgrade(ud)

    def visitDecl(self, decl):
        return _HybridDecl(decl.type, decl.progname, decl.attributes)

    def visitMessageDecl(self, md):
        md.namespace = self.protocolName
        md.params = [inparam.accept(self) for inparam in md.inParams]
        md.returns = [outparam.accept(self) for outparam in md.outParams]
        MessageDecl.upgrade(md)
# -----------------------------------------------------------------------------
def msgenums(protocol, pretty=False):
    """Build the |MessageType| enum for |protocol|.

    The enum starts with <Name>Start (= <message start> << 16), then has one
    id per message (and one per reply, when the message has a reply), and
    ends with <Name>End. When |pretty| is true the human-readable names are
    used instead of the message ids."""
    msgenum = TypeEnum("MessageType")
    start = _messageStartName(protocol.decl.type) + " << 16"
    msgenum.addId(protocol.name + "Start", start)
    for md in protocol.messageDecls:
        if pretty:
            msgenum.addId(md.prettyMsgName())
            if md.hasReply():
                msgenum.addId(md.prettyReplyName())
        else:
            msgenum.addId(md.msgId())
            if md.hasReply():
                msgenum.addId(md.replyId())
    msgenum.addId(protocol.name + "End")
    return msgenum
class _GenerateProtocolCode(ipdl.ast.Visitor):
    """Creates code common to both the parent and child actors.

    Fills the shared |name|.h / |name|.cpp pair: struct and union
    declarations with their ParamTraits, the MessageType enum, the
    per-message constructor functions and the CreateEndpoints overloads.
    Drive it through |lower()|.
    """

    def __init__(self):
        self.protocol = None  # protocol we're generating a class for
        self.hdrfile = None  # what will become Protocol.h
        self.cppfile = None  # what will become Protocol.cpp
        # Header names to #include from the generated .cpp file.
        self.cppIncludeHeaders = []
        # Out-of-line struct/union method and ParamTraits definitions (.cpp).
        self.structUnionDefns = []
        # Out-of-line definitions of free functions placed in the protocol
        # namespace (.cpp).
        self.funcDefns = []

    def lower(self, tu, cxxHeaderFile, cxxFile, segmentcapacitydict):
        """Generate code for |tu| into |cxxHeaderFile| and |cxxFile|.

        |segmentcapacitydict| maps "<namespace>::<message name>" to the
        segment capacity to use for that message (see visitProtocol)."""
        self.protocol = tu.protocol
        self.hdrfile = cxxHeaderFile
        self.cppfile = cxxFile
        self.segmentcapacitydict = segmentcapacitydict
        tu.accept(self)

    def visitTranslationUnit(self, tu):
        hf = self.hdrfile

        hf.addthing(_DISCLAIMER)
        hf.addthings(_includeGuardStart(hf))
        hf.addthing(Whitespace.NL)

        for inc in builtinHeaderIncludes:
            self.visitBuiltinCxxInclude(inc)

        # Compute the set of includes we need for declared structure/union
        # classes for this protocol.
        typesToIncludes = {}
        for using in tu.using:
            typestr = str(using.type)
            if typestr not in typesToIncludes:
                typesToIncludes[typestr] = using.header
            else:
                # The same type must not be imported from two headers.
                assert typesToIncludes[typestr] == using.header

        aggregateTypeIncludes = set()
        for su in tu.structsAndUnions:
            typedeps = _ComputeTypeDeps(su.decl.type, typesToIncludes)
            if isinstance(su, ipdl.ast.StructDecl):
                aggregateTypeIncludes.add("mozilla/ipc/IPDLStructMember.h")
                for f in su.fields:
                    f.ipdltype.accept(typedeps)
            elif isinstance(su, ipdl.ast.UnionDecl):
                for c in su.components:
                    c.ipdltype.accept(typedeps)

            aggregateTypeIncludes.update(typedeps.includeHeaders)

        if len(aggregateTypeIncludes) != 0:
            hf.addthing(Whitespace.NL)
            hf.addthings([Whitespace("// Headers for typedefs"), Whitespace.NL])

            # Sorted so the generated header is deterministic.
            for headername in sorted(iter(aggregateTypeIncludes)):
                hf.addthing(CppDirective("include", '"' + headername + '"'))

        # Manually run Visitor.visitTranslationUnit. For dependency resolution
        # we need to handle structs and unions separately.
        for cxxInc in tu.cxxIncludes:
            cxxInc.accept(self)
        for inc in tu.includes:
            inc.accept(self)
        self.generateStructsAndUnions(tu)
        for using in tu.builtinUsing:
            using.accept(self)
        for using in tu.using:
            using.accept(self)
        if tu.protocol:
            tu.protocol.accept(self)

        if tu.filetype == "header":
            self.cppIncludeHeaders.append(_ipdlhHeaderName(tu) + ".h")

        hf.addthing(Whitespace.NL)
        hf.addthings(_includeGuardEnd(hf))

        # The .cpp file: includes gathered above, then the builtin cpp
        # includes, then (for protocols) the namespaced function definitions,
        # then the struct/union method definitions.
        cf = self.cppfile
        cf.addthings(
            (
                [_DISCLAIMER, Whitespace.NL]
                + [
                    CppDirective("include", '"' + h + '"')
                    for h in self.cppIncludeHeaders
                ]
                + [Whitespace.NL]
                + [
                    CppDirective("include", '"%s"' % filename)
                    for filename in ipdl.builtin.CppIncludes
                ]
                + [Whitespace.NL]
            )
        )

        if self.protocol:
            # construct the namespace into which we'll stick all our defns
            ns = Namespace(self.protocol.name)
            cf.addthing(_putInNamespaces(ns, self.protocol.namespaces))
            ns.addstmts(([Whitespace.NL] + self.funcDefns + [Whitespace.NL]))

        cf.addthings(self.structUnionDefns)

    def visitBuiltinCxxInclude(self, inc):
        self.hdrfile.addthing(CppDirective("include", '"' + inc.file + '"'))

    def visitCxxInclude(self, inc):
        # C++ includes declared in the IPDL source only go into the .cpp.
        self.cppIncludeHeaders.append(inc.file)

    def visitInclude(self, inc):
        if inc.tu.filetype == "header":
            # An included .ipdlh: include its generated header from ours.
            self.hdrfile.addthing(
                CppDirective("include", '"' + _ipdlhHeaderName(inc.tu) + '.h"')
            )
            # Inherit cpp includes defined by imported header files, as they may
            # be required to serialize an imported `using` type.
            for cxxinc in inc.tu.cxxIncludes:
                cxxinc.accept(self)
        else:
            # An included protocol: the .cpp needs both of its actor headers.
            self.cppIncludeHeaders += [
                _protocolHeaderName(inc.tu.protocol, "parent") + ".h",
                _protocolHeaderName(inc.tu.protocol, "child") + ".h",
            ]

    def generateStructsAndUnions(self, tu):
        """Generate the definitions for all structs and unions. This will
        re-order the declarations if needed in the C++ code such that
        dependencies have already been defined."""
        # Maps each struct/union type to (types it needs fully declared,
        # header items for its declaration); insertion order = IPDL order.
        decls = OrderedDict()
        for su in tu.structsAndUnions:
            if isinstance(su, StructDecl):
                which = "struct"
                forwarddecls, fulldecltypes, cls = _generateCxxStruct(su)
                traitsdecl, traitsdefns = _ParamTraits.structPickling(su.decl.type)
            else:
                assert isinstance(su, UnionDecl)
                which = "union"
                forwarddecls, fulldecltypes, cls = _generateCxxUnion(su)
                traitsdecl, traitsdefns = _ParamTraits.unionPickling(su.decl.type)

            # Class declaration goes in the header; method bodies go in the
            # .cpp via structUnionDefns.
            clsdecl, methoddefns = _splitClassDeclDefn(cls)

            # Store the declarations in the decls map so we can emit in
            # dependency order.
            decls[su.decl.type] = (
                fulldecltypes,
                [Whitespace.NL]
                + forwarddecls
                + [
                    Whitespace(
                        """
//-----------------------------------------------------------------------------
// Declaration of the IPDL type |%s %s|
//
"""
                        % (which, su.name)
                    ),
                    _putInNamespaces(clsdecl, su.namespaces),
                ]
                + [Whitespace.NL, traitsdecl],
            )

            self.structUnionDefns.extend(
                [
                    Whitespace(
                        """
//-----------------------------------------------------------------------------
// Method definitions for the IPDL type |%s %s|
//
"""
                        % (which, su.name)
                    ),
                    _putInNamespaces(methoddefns, su.namespaces),
                    Whitespace.NL,
                    traitsdefns,
                ]
            )

        # Generate the declarations structs in dependency order.
        # Depth-first: emit every still-pending dependency before |defn|.
        # Entries are removed from |decls| before recursing, so each type is
        # emitted exactly once.
        def gen_struct(deps, defn):
            for dep in deps:
                if dep in decls:
                    d, t = decls[dep]
                    del decls[dep]
                    gen_struct(d, t)
            self.hdrfile.addthings(defn)

        while len(decls) > 0:
            # Pop from the front (FIFO) to keep IPDL order where possible.
            _, (d, t) = decls.popitem(False)
            gen_struct(d, t)

    def visitProtocol(self, p):
        self.cppIncludeHeaders.append(_protocolHeaderName(self.protocol, "") + ".h")
        self.cppIncludeHeaders.append(
            _protocolHeaderName(self.protocol, "Parent") + ".h"
        )
        self.cppIncludeHeaders.append(
            _protocolHeaderName(self.protocol, "Child") + ".h"
        )

        # Forward declare our own actors.
        self.hdrfile.addthings(
            [
                Whitespace.NL,
                _makeForwardDeclForActor(p.decl.type, "Parent"),
                _makeForwardDeclForActor(p.decl.type, "Child"),
            ]
        )

        self.hdrfile.addthing(
            Whitespace(
                """
//-----------------------------------------------------------------------------
// Code common to %sChild and %sParent
//
"""
                % (p.name, p.name)
            )
        )

        # construct the namespace into which we'll stick all our decls
        ns = Namespace(self.protocol.name)
        self.hdrfile.addthing(_putInNamespaces(ns, p.namespaces))
        ns.addstmt(Whitespace.NL)

        # Declarations go in the header namespace; definitions are collected
        # in funcDefns and emitted into the .cpp by visitTranslationUnit.
        for func in self.genEndpointFuncs():
            edecl, edefn = _splitFuncDeclDefn(func)
            ns.addstmts([edecl, Whitespace.NL])
            self.funcDefns.append(edefn)

        # spit out message type enum and classes
        msgenum = msgenums(self.protocol)
        ns.addstmts([StmtDecl(Decl(msgenum, "")), Whitespace.NL])

        for md in p.messageDecls:
            decls = []

            # Look up the segment capacity used for serializing this
            # message. If the capacity is not specified, use '0' for
            # the default capacity (defined in ipc_message.cc)
            name = "%s::%s" % (md.namespace, md.decl.progname)
            segmentcapacity = self.segmentcapacitydict.get(name, 0)

            mfDecl, mfDefn = _splitFuncDeclDefn(
                _generateMessageConstructor(md, segmentcapacity, p, forReply=False)
            )
            decls.append(mfDecl)
            self.funcDefns.append(mfDefn)

            if md.hasReply():
                # Replies always use the default segment capacity.
                rfDecl, rfDefn = _splitFuncDeclDefn(
                    _generateMessageConstructor(md, 0, p, forReply=True)
                )
                decls.append(rfDecl)
                self.funcDefns.append(rfDefn)

            decls.append(Whitespace.NL)
            ns.addstmts(decls)

        ns.addstmts([Whitespace.NL, Whitespace.NL])

    # Generate code for PFoo::CreateEndpoints.
    def genEndpointFuncs(self):
        """Return the CreateEndpoints overloads for this protocol.

        The overload taking explicit parent/child destination pids is always
        generated; the pid-less overload is added only when the protocol type
        reports no "other pid"."""
        p = self.protocol.decl.type
        tparent = _cxxBareType(ActorType(p), "Parent", fq=True)
        tchild = _cxxBareType(ActorType(p), "Child", fq=True)

        def mkOverload(includepids):
            params = []
            if includepids:
                params = [
                    Decl(Type("base::ProcessId"), "aParentDestPid"),
                    Decl(Type("base::ProcessId"), "aChildDestPid"),
                ]
            params += [
                Decl(
                    Type("mozilla::ipc::Endpoint<" + tparent.name + ">", ptr=True),
                    "aParent",
                ),
                Decl(
                    Type("mozilla::ipc::Endpoint<" + tchild.name + ">", ptr=True),
                    "aChild",
                ),
            ]
            openfunc = MethodDefn(
                MethodDecl("CreateEndpoints", params=params, ret=Type.NSRESULT)
            )
            # Forward every parameter, in order, to
            # mozilla::ipc::CreateEndpoints.
            openfunc.addcode(
                """
return mozilla::ipc::CreateEndpoints(
mozilla::ipc::PrivateIPDLInterface(),
$,{args});
""",
                args=[ExprVar(d.name) for d in params],
            )
            return openfunc

        funcs = [mkOverload(True)]
        if not p.hasOtherPid():
            funcs.append(mkOverload(False))
        return funcs
# --------------------------------------------------
# C++ IPC::Message priority enumerator names, parallel to
# ipdl.ast.priorityList (e.g. "normal" -> "NORMAL_PRIORITY").
cppPriorityList = [prio.upper() + "_PRIORITY" for prio in ipdl.ast.priorityList]
def _generateMessageConstructor(md, segmentSize, protocol, forReply=False):
    """Generate the C++ factory function that allocates the IPC::Message for
    |md|, or for its reply when |forReply| is true.

    The generated function takes an `int32_t routingId` and returns a
    `mozilla::UniquePtr<IPC::Message>` created by IPC::Message::IPDLMessage
    from the routing id, the message (or reply) id, the initial segment
    capacity and the header flags. The flags are derived from the message's
    nesting, priority, compression, lazy-send, constructor, sync and reply
    attributes.

    |segmentSize| is anything accepted by int(). A falsy value (None, 0, "")
    yields capacity 0, which selects the default capacity. |protocol| is
    unused and kept only for caller compatibility.
    """
    mtype = md.decl.type

    if forReply:
        clsname = md.replyCtorFunc()
        msgid = md.replyId()
        replyEnum = "REPLY"
        prioEnum = cppPriorityList[mtype.replyPrio]
    else:
        clsname = md.msgCtorFunc()
        msgid = md.msgId()
        replyEnum = "NOT_REPLY"
        prioEnum = cppPriorityList[mtype.prio]

    routingId = ExprVar("routingId")

    func = FunctionDefn(
        FunctionDecl(
            clsname,
            params=[Decl(Type("int32_t"), routingId.name)],
            ret=Type("mozilla::UniquePtr<IPC::Message>"),
        )
    )

    compress = mtype.compress
    if not compress:
        compression = "COMPRESSION_NONE"
    elif compress.value == "all":
        compression = "COMPRESSION_ALL"
    else:
        assert compress.value is None
        compression = "COMPRESSION_ENABLED"

    lazySendEnum = "LAZY_SEND" if mtype.lazySend else "EAGER_SEND"

    nested = mtype.nested
    if nested == ipdl.ast.NOT_NESTED:
        nestedEnum = "NOT_NESTED"
    elif nested == ipdl.ast.INSIDE_SYNC_NESTED:
        nestedEnum = "NESTED_INSIDE_SYNC"
    else:
        assert nested == ipdl.ast.INSIDE_CPOW_NESTED
        nestedEnum = "NESTED_INSIDE_CPOW"

    syncEnum = "SYNC" if mtype.isSync() else "ASYNC"
    ctorEnum = "CONSTRUCTOR" if mtype.isCtor() else "NOT_CONSTRUCTOR"

    # Positional order is significant. It presumably mirrors the C++
    # IPC::Message::HeaderFlags constructor, so do not reorder.
    flagEnums = (
        nestedEnum,
        prioEnum,
        compression,
        lazySendEnum,
        ctorEnum,
        syncEnum,
        replyEnum,
    )
    flags = ExprCall(
        ExprVar("IPC::Message::HeaderFlags"),
        args=[ExprVar("IPC::Message::" + name) for name in flagEnums],
    )

    # Normalize the capacity. The falsiness check has to come before int().
    # Previously int() ran first, which made the fallback dead code, so a
    # missing capacity (None) raised TypeError instead of using 0.
    segmentSize = int(segmentSize) if segmentSize else 0

    func.addstmt(
        StmtReturn(
            ExprCall(
                ExprVar("IPC::Message::IPDLMessage"),
                args=[
                    routingId,
                    ExprVar(msgid),
                    ExprLiteral.Int(segmentSize),
                    flags,
                ],
            )
        )
    )

    return func
# --------------------------------------------------
class _ParamTraits:
    """Builders for the IPC::ParamTraits<T> specializations that serialize
    IPDL actors, structs and unions, plus the checked read/write statement
    helpers they share (each value is followed by a hashed sentinel)."""

    # Names of the parameters of the generated Write/Read methods.
    var = ExprVar("aVar")
    writervar = ExprVar("aWriter")
    readervar = ExprVar("aReader")

    @classmethod
    def ifsideis(cls, rdrwtr, side, then, els=None):
        """Return `if (<side> == rdrwtr->GetActor()->GetSide()) then [else els]`."""
        ifstmt = StmtIf(
            ExprBinary(
                _cxxSide(side),
                "==",
                ExprCode("${rdrwtr}->GetActor()->GetSide()", rdrwtr=rdrwtr),
            )
        )
        ifstmt.addifstmt(then)
        if els is not None:
            ifstmt.addelsestmt(els)
        return ifstmt

    @classmethod
    def fatalError(cls, rdrwtr, reason):
        """Return a `rdrwtr->FatalError("<reason>");` statement."""
        return StmtCode(
            "${rdrwtr}->FatalError(${reason});",
            rdrwtr=rdrwtr,
            reason=ExprLiteral.String(reason),
        )

    @classmethod
    def writeSentinel(cls, writervar, sentinelKey):
        """Return statements writing hashfunc(sentinelKey) as a sentinel,
        preceded by a comment naming the key."""
        return [
            Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True),
            StmtExpr(
                ExprCall(
                    ExprSelect(writervar, "->", "WriteSentinel"),
                    args=[ExprLiteral.Int(hashfunc(sentinelKey))],
                )
            ),
        ]

    @classmethod
    def readSentinel(cls, readervar, sentinelKey, sentinelFail):
        """Return statements that read the sentinel for |sentinelKey| and run
        the |sentinelFail| statements when ReadSentinel fails."""
        # Read the sentinel
        read = ExprCall(
            ExprSelect(readervar, "->", "ReadSentinel"),
            args=[ExprLiteral.Int(hashfunc(sentinelKey))],
        )
        ifsentinel = StmtIf(ExprNot(read))
        ifsentinel.addifstmts(sentinelFail)

        return [
            Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True),
            ifsentinel,
        ]

    @classmethod
    def write(cls, var, writervar, ipdltype=None):
        """Return an `IPC::WriteParam(writervar, var)` call expression,
        wrapping |var| in a move when |ipdltype| needs to be moved to send."""
        if ipdltype and _cxxTypeNeedsMoveForSend(ipdltype):
            var = ExprMove(var)
        return ExprCall(ExprVar("IPC::WriteParam"), args=[writervar, var])

    @classmethod
    def checkedWrite(cls, ipdltype, var, writervar, sentinelKey):
        """Return a Block that writes |var| and then its sentinel."""
        assert sentinelKey
        block = Block()

        block.addstmts(
            [
                StmtExpr(cls.write(var, writervar, ipdltype)),
            ]
        )
        block.addstmts(cls.writeSentinel(writervar, sentinelKey))
        return block

    @classmethod
    def bulkSentinelKey(cls, fields):
        """Sentinel key for a bulk group: field basenames joined by " | "."""
        return " | ".join(f.basename for f in fields)

    @classmethod
    def checkedBulkWrite(cls, var, size, fields):
        """Return a Block that writes |len(fields)| POD fields of |size| bytes
        each with a single WriteBytes call starting at the first field's
        address, followed by one sentinel for the whole group."""
        block = Block()
        first = fields[0]

        block.addstmts(
            [
                StmtExpr(
                    ExprCall(
                        ExprSelect(cls.writervar, "->", "WriteBytes"),
                        args=[
                            ExprAddrOf(
                                ExprCall(first.getMethod(thisexpr=var, sel="."))
                            ),
                            ExprLiteral.Int(size * len(fields)),
                        ],
                    )
                )
            ]
        )
        block.addstmts(cls.writeSentinel(cls.writervar, cls.bulkSentinelKey(fields)))

        return block

    @classmethod
    def checkedBulkRead(cls, var, size, fields):
        """Read counterpart of checkedBulkWrite: ReadBytesInto the first
        field's address, reporting FatalError and returning an error
        ReadResult on failure, then check the group's sentinel.

        |var| is accessed with "->" (on the read side it is the result
        variable rather than the value itself)."""
        block = Block()
        first = fields[0]

        readbytes = ExprCall(
            ExprSelect(cls.readervar, "->", "ReadBytesInto"),
            args=[
                ExprAddrOf(ExprCall(first.getMethod(thisexpr=var, sel="->"))),
                ExprLiteral.Int(size * len(fields)),
            ],
        )
        ifbad = StmtIf(ExprNot(readbytes))
        errmsg = "Error bulk reading fields from %s" % first.ipdltype.name()
        ifbad.addifstmts(
            [cls.fatalError(cls.readervar, errmsg), StmtReturn(readResultError())]
        )
        block.addstmt(ifbad)
        block.addstmts(
            cls.readSentinel(
                cls.readervar,
                cls.bulkSentinelKey(fields),
                errfnSentinel(readResultError())(errmsg),
            )
        )

        return block

    @classmethod
    def checkedRead(
        cls,
        ipdltype,
        cxxtype,
        var,
        readervar,
        errfn,
        paramtype,
        sentinelKey,
        errfnSentinel,
    ):
        """Return a Block that reads a |cxxtype| with IPC::ReadParam into a
        Maybe named maybe__<var>, runs |errfn| statements if that fails,
        binds |var| as a reference to the contained value, then checks the
        sentinel (running |errfnSentinel| statements on failure).

        |paramtype| is either a list of arguments for errfn/errfnSentinel or a
        string, which becomes ["Error deserializing " + paramtype].
        |ipdltype| is not used here."""
        assert isinstance(var, ExprVar)

        if not isinstance(paramtype, list):
            paramtype = ["Error deserializing " + paramtype]

        block = Block()

        # Read the data
        block.addcode(
            """
auto ${maybevar} = IPC::ReadParam<${ty}>(${reader});
if (!${maybevar}) {
$*{errfn}
}
auto& ${var} = *${maybevar};
""",
            maybevar=ExprVar("maybe__" + var.name),
            ty=cxxtype,
            reader=readervar,
            errfn=errfn(*paramtype),
            var=var,
        )

        block.addstmts(
            cls.readSentinel(readervar, sentinelKey, errfnSentinel(*paramtype))
        )

        return block

    # Helper wrapper for checkedRead for use within _ParamTraits
    @classmethod
    def _checkedRead(cls, ipdltype, cxxtype, var, sentinelKey, what):
        """checkedRead using this class's reader variable; on failure the
        generated code reports FatalError and returns an error ReadResult."""

        def errfn(msg):
            return [cls.fatalError(cls.readervar, msg), StmtReturn(readResultError())]

        return cls.checkedRead(
            ipdltype,
            cxxtype,
            var,
            cls.readervar,
            errfn=errfn,
            paramtype=what,
            sentinelKey=sentinelKey,
            errfnSentinel=errfnSentinel(readResultError()),
        )

    @classmethod
    def generateDecl(cls, fortype, write, read, needsmove=False):
        """Build `template<> struct IPC::ParamTraits<fortype>` with static
        Write/Read methods whose bodies are |write| and |read|.

        When |needsmove| is true, Write takes `paramType&&` instead of
        `const paramType&`. Returns (declaration, definitions), both wrapped
        in namespace IPC."""
        # ParamTraits impls are selected ignoring constness, and references.
        pt = Class(
            "ParamTraits",
            specializes=Type(
                fortype.name, T=fortype.T, inner=fortype.inner, ptr=fortype.ptr
            ),
            struct=True,
        )

        # typedef T paramType;
        pt.addstmt(Typedef(fortype, "paramType"))

        # static void Write(IPC::MessageWriter*, const T&) -- or T&& when
        # |needsmove| is set.
        if needsmove:
            intype = Type("paramType", rvalref=True)
        else:
            intype = Type("paramType", ref=True, const=True)
        writemthd = MethodDefn(
            MethodDecl(
                "Write",
                params=[
                    Decl(Type("IPC::MessageWriter", ptr=True), cls.writervar.name),
                    Decl(intype, cls.var.name),
                ],
                methodspec=MethodSpec.STATIC,
            )
        )
        writemthd.addstmts(write)
        pt.addstmt(writemthd)

        # static ReadResult<T> Read(MessageReader*);
        readmthd = MethodDefn(
            MethodDecl(
                "Read",
                params=[
                    Decl(Type("IPC::MessageReader", ptr=True), cls.readervar.name),
                ],
                ret=Type("IPC::ReadResult<paramType>"),
                methodspec=MethodSpec.STATIC,
            )
        )
        readmthd.addstmts(read)
        pt.addstmt(readmthd)

        # Split the class into declaration and definition
        clsdecl, methoddefns = _splitClassDeclDefn(pt)

        namespaces = [Namespace("IPC")]
        clsns = _putInNamespaces(clsdecl, namespaces)
        defns = _putInNamespaces(methoddefns, namespaces)
        return clsns, defns

    @classmethod
    def actorPickling(cls, actortype, side):
        """Generates pickling for IPDL actors. This is a |nullable| deserializer.
        Write and read callers will perform nullability validation."""
        cxxtype = _cxxBareType(actortype, side, fq=True)

        # Write: serialize the actor's Id (0 for null) after checking that
        # it was not deleted, belongs to the same channel and can still send.
        write = StmtCode(
            """
MOZ_RELEASE_ASSERT(
${writervar}->GetActor(),
"Cannot serialize managed actors without an actor");

int32_t id;
if (!${var}) {
id = 0;  // kNullActorId
} else {
id = ${var}->Id();
if (id == 1) {  // kFreedActorId
${var}->FatalError("Actor has been |delete|d");
}
MOZ_RELEASE_ASSERT(
${writervar}->GetActor()->GetIPCChannel() == ${var}->GetIPCChannel(),
"Actor must be from the same channel as the"
" actor it's being sent over");
MOZ_RELEASE_ASSERT(
${var}->CanSend(),
"Actor must still be open when sending");
}

${write};
""",
            var=cls.var,
            writervar=cls.writervar,
            write=cls.write(ExprVar("id"), cls.writervar),
        )

        # bool Read(..) impl
        # Read: delegate to the reader's actor ReadActor() and static_cast the
        # resulting IProtocol*; returns an empty result when nothing is read.
        read = StmtCode(
            """
MOZ_RELEASE_ASSERT(
${readervar}->GetActor(),
"Cannot deserialize managed actors without an actor");
mozilla::Maybe<mozilla::ipc::IProtocol*> actor = ${readervar}->GetActor()
->ReadActor(${readervar}, true, ${actortype}, ${protocolid});
if (actor.isSome()) {
return static_cast<${cxxtype}>(actor.ref());
}
return {};
""",
            readervar=cls.readervar,
            actortype=ExprLiteral.String(actortype.name()),
            protocolid=_protocolId(actortype),
            cxxtype=cxxtype,
        )

        return cls.generateDecl(cxxtype, [write], [read])

    @classmethod
    def structPickling(cls, structtype):
        """Generate ParamTraits for an IPDL struct.

        Non-POD fields are written/read one by one in IPDL order and used to
        construct the result; POD fields are zero-initialized in the
        constructor and then bulk-copied in member order, grouped by size."""
        sd = structtype._ast
        # NOTE: Not using _cxxBareType here as we don't have a side
        cxxtype = Type(structtype.fullname())

        write = []
        read = []

        # First serialize/deserialize all non-pod data in IPDL order. These need
        # to be read/written first because they'll be used to invoke the IPDL
        # struct's constructor.
        ctorargs = []
        for f in sd.fields_ipdl_order():
            if pod_size(f.ipdltype) == pod_size_sentinel:
                write.append(
                    cls.checkedWrite(
                        f.ipdltype,
                        ExprCall(f.getMethod(thisexpr=cls.var, sel=".")),
                        cls.writervar,
                        sentinelKey=f.basename,
                    )
                )
                read.append(
                    cls._checkedRead(
                        f.ipdltype,
                        f.bareType(fq=True),
                        f.argVar(),
                        f.basename,
                        "'"
                        + f.getMethod().name
                        + "' "
                        + "("
                        + f.ipdltype.name()
                        + ") member of "
                        + "'"
                        + structtype.name()
                        + "'",
                    )
                )
                if _cxxTypeCanMove(f.ipdltype):
                    ctorargs.append(ExprMove(f.argVar()))
                else:
                    ctorargs.append(f.argVar())
            else:
                # We're going to bulk-read in this value later, so we'll just
                # zero-initialize it for now.
                ctorargs.append(ExprCode("${type}{0}", type=f.bareType(fq=True)))

        # The result is constructed in place inside the ReadResult.
        resultvar = ExprVar("result__")
        read.append(
            StmtDecl(
                Decl(_cxxReadResultType(Type("paramType")), resultvar.name),
                initargs=[ExprVar("std::in_place")] + ctorargs,
            )
        )

        # After non-pod data, bulk read/write pod data in member order. This has
        # to be done after the result has been constructed, so that we have
        # somewhere to read into.
        for size, fields in itertools.groupby(
            sd.fields_member_order(), lambda f: pod_size(f.ipdltype)
        ):
            if size != pod_size_sentinel:
                fields = list(fields)
                write.append(cls.checkedBulkWrite(cls.var, size, fields))
                read.append(cls.checkedBulkRead(resultvar, size, fields))

        read.append(StmtReturn(resultvar))

        return cls.generateDecl(
            cxxtype, write, read, needsmove=_cxxTypeNeedsMoveForSend(structtype)
        )

    @classmethod
    def unionPickling(cls, uniontype):
        """Generate ParamTraits for an IPDL union: write/read the integer
        variant tag (with sentinel), then switch on it to write/read the
        active variant. Unknown tags produce a FatalError."""
        # NOTE: Not using _cxxBareType here as we don't have a side
        cxxtype = Type(uniontype.fullname())
        ud = uniontype._ast

        # Use a typedef to set up an alias so it's easier to reference the
        # union type.
        alias = "union__"
        typevar = ExprVar("type")

        prelude = [
            Typedef(cxxtype, alias),
        ]

        writeswitch = StmtSwitch(typevar)
        write = prelude + [
            StmtDecl(Decl(Type.INT, typevar.name), init=ud.callType(cls.var)),
            cls.checkedWrite(
                None, typevar, cls.writervar, sentinelKey=uniontype.name()
            ),
            Whitespace.NL,
            writeswitch,
        ]

        readswitch = StmtSwitch(typevar)
        read = prelude + [
            cls._checkedRead(
                None,
                Type.INT,
                typevar,
                uniontype.name(),
                "type of union " + uniontype.name(),
            ),
            Whitespace.NL,
            readswitch,
        ]

        # One case per variant in each switch; the same CaseLabel object is
        # used for both switches.
        for c in ud.components:
            caselabel = CaseLabel(alias + "::" + c.enum())
            origenum = c.enum()

            writecase = StmtBlock()
            wstmt = cls.checkedWrite(
                c.ipdltype,
                ExprCall(ExprSelect(cls.var, ".", c.getTypeName())),
                cls.writervar,
                sentinelKey=c.enum(),
            )
            writecase.addstmts([wstmt, StmtReturn()])
            writeswitch.addcase(caselabel, writecase)

            readcase = StmtBlock()
            tmpvar = ExprVar("tmp")
            readcase.addstmts(
                [
                    cls._checkedRead(
                        c.ipdltype,
                        c.bareType(fq=True),
                        tmpvar,
                        origenum,
                        "variant " + origenum + " of union " + uniontype.name(),
                    ),
                    StmtReturn(ExprMove(tmpvar)),
                ]
            )
            readswitch.addcase(caselabel, readcase)

        # Add the error default case
        writeswitch.addcase(
            DefaultLabel(),
            StmtBlock(
                [
                    cls.fatalError(
                        cls.writervar, "unknown variant of union " + uniontype.name()
                    ),
                    StmtReturn(),
                ]
            ),
        )
        readswitch.addcase(
            DefaultLabel(),
            StmtBlock(
                [
                    cls.fatalError(
                        cls.readervar, "unknown variant of union " + uniontype.name()
                    ),
                    StmtReturn(readResultError()),
                ]
            ),
        )

        return cls.generateDecl(
            cxxtype, write, read, needsmove=_cxxTypeNeedsMoveForSend(uniontype)
        )
# --------------------------------------------------
class _ComputeTypeDeps(TypeVisitor):
    """Pass that gathers the C++ types that a particular IPDL type
    (recursively) depends on. There are three kinds of dependencies: (i)
    types that need forward declaration; (ii) types that need a |using|
    stmt; (iii) IPDL structs or unions which must be fully declared
    before this struct. Some types generate multiple kinds.

    Results accumulate in |usingTypedefs|, |forwardDeclStmts|,
    |fullDeclTypes| and |includeHeaders|. When |typesToIncludes| (a map from
    fully-qualified type name to header) is given, the headers of any
    referenced mapped types are added to |includeHeaders|."""

    def __init__(self, fortype, typesToIncludes=None):
        ipdl.type.TypeVisitor.__init__(self)
        self.usingTypedefs = []
        self.forwardDeclStmts = []
        self.fullDeclTypes = []
        self.includeHeaders = set()
        self.fortype = fortype
        self.typesToIncludes = typesToIncludes

    def maybeTypedef(self, fqname, name, templateargs=None):
        """Record `typedef fqname name` when the names differ, and the
        header for |fqname| if one is known.

        |templateargs| defaults to a fresh empty list on every call. The old
        mutable default `[]` was one list object shared by every Typedef
        created without explicit template arguments."""
        assert fqname.startswith("::")
        if templateargs is None:
            templateargs = []
        if fqname != name:
            self.usingTypedefs.append(Typedef(Type(fqname), name, templateargs))
        if self.typesToIncludes is not None and fqname in self.typesToIncludes:
            self.includeHeaders.add(self.typesToIncludes[fqname])

    def visitImportedCxxType(self, t):
        if t in self.visited:
            return
        self.visited.add(t)
        self.maybeTypedef(t.fullname(), t.name())

    def visitActorType(self, t):
        if t in self.visited:
            return
        self.visited.add(t)

        fqname, name = t.fullname(), t.name()

        self.includeHeaders.add("mozilla/ipc/SideVariant.h")
        self.maybeTypedef(_actorName(fqname, "Parent"), _actorName(name, "Parent"))
        self.maybeTypedef(_actorName(fqname, "Child"), _actorName(name, "Child"))

        self.forwardDeclStmts.extend(
            [
                _makeForwardDeclForActor(t.protocol, "parent"),
                Whitespace.NL,
                _makeForwardDeclForActor(t.protocol, "child"),
                Whitespace.NL,
            ]
        )

    def visitStructOrUnionType(self, su, defaultVisit):
        # The type being analyzed never depends on itself.
        if su in self.visited or su == self.fortype:
            return
        self.visited.add(su)
        self.maybeTypedef(su.fullname(), su.name())

        # Mutually recursive fields in unions are behind indirection, so we only
        # need a forward decl, and don't need a full type declaration.
        if isinstance(self.fortype, UnionType) and self.fortype.mutuallyRecursiveWith(
            su
        ):
            self.forwardDeclStmts.append(_makeForwardDecl(su))
        else:
            self.fullDeclTypes.append(su)

        return defaultVisit(self, su)

    def visitStructType(self, t):
        return self.visitStructOrUnionType(t, TypeVisitor.visitStructType)

    def visitUnionType(self, t):
        return self.visitStructOrUnionType(t, TypeVisitor.visitUnionType)

    def visitArrayType(self, t):
        return TypeVisitor.visitArrayType(self, t)

    def visitMaybeType(self, m):
        return TypeVisitor.visitMaybeType(self, m)

    def visitShmemType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef("::mozilla::ipc::Shmem", "Shmem")

    def visitByteBufType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef("::mozilla::ipc::ByteBuf", "ByteBuf")

    def visitFDType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef("::mozilla::ipc::FileDescriptor", "FileDescriptor")

    def visitEndpointType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.includeHeaders.add("mozilla/ipc/Endpoint.h")
        self.maybeTypedef("::mozilla::ipc::Endpoint", "Endpoint", ["FooSide"])
        self.visitActorType(s.actor)

    def visitManagedEndpointType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.includeHeaders.add("mozilla/ipc/Endpoint.h")
        self.maybeTypedef(
            "::mozilla::ipc::ManagedEndpoint", "ManagedEndpoint", ["FooSide"]
        )
        self.visitActorType(s.actor)

    def visitUniquePtrType(self, s):
        return TypeVisitor.visitUniquePtrType(self, s)

    # These kinds of types never appear as struct/union field types.
    def visitVoidType(self, v):
        assert 0

    def visitMessageType(self, v):
        assert 0

    def visitProtocolType(self, v):
        assert 0
def _fieldStaticAssertions(sd):
    """Return static_assert statements that verify the bulk-copy layout.

    For each run of two or more consecutive same-size POD fields of |sd|, in
    member order, assert that the offset between the run's first and last
    member equals size * (count - 1), i.e. the run is contiguous."""
    checks = []
    bysize = itertools.groupby(sd.fields_member_order(), lambda f: pod_size(f.ipdltype))
    for size, group in bysize:
        if size == pod_size_sentinel:
            # Non-POD fields are serialized individually; nothing to check.
            continue
        run = list(group)
        if len(run) < 2:
            # A single field is trivially contiguous.
            continue
        checks.append(
            StmtCode(
                """
static_assert(
(offsetof(${struct}, ${last}) - offsetof(${struct}, ${first})) == ${expected},
"Bad assumptions about field layout!");
""",
                struct=sd.name,
                first=run[0].memberVar(),
                last=run[-1].memberVar(),
                expected=ExprLiteral.Int(size * (len(run) - 1)),
            )
        )
    return checks
def _generateCxxStruct(sd):
    """Generate the C++ class for the IPDL struct |sd|.

    Returns (forwarddeclstmts, fulldecltypes, struct):
      - forwarddeclstmts: forward declarations the class needs,
      - fulldecltypes: IPDL struct/union types that must be fully declared
        before this class (see _ComputeTypeDeps),
      - struct: the Class node. It has private typedefs; a defaulted default
        constructor; field constructors (const& and, when they differ, &&);
        operator==/!= when the struct has the "Comparable" attribute;
        mutable and const getters per field; a StaticAssertions() layout
        check when applicable; and IPDLStructMember-wrapped members in packed
        member order.
    """
    # compute all the typedefs and forward decls we need to make
    gettypedeps = _ComputeTypeDeps(sd.decl.type)
    for f in sd.fields:
        f.ipdltype.accept(gettypedeps)

    usingTypedefs = gettypedeps.usingTypedefs
    forwarddeclstmts = gettypedeps.forwardDeclStmts
    fulldecltypes = gettypedeps.fullDeclTypes

    struct = Class(sd.name, final=True)
    struct.addstmts([Label.PRIVATE] + usingTypedefs + [Whitespace.NL, Label.PUBLIC])

    constreftype = Type(sd.name, const=True, ref=True)

    # Struct()
    # We want the default constructor to be declared if it is available, but
    # some of our members may not be default-constructible. Silence the
    # warning which clang generates in that case.
    #
    # Members which need value initialization will be handled by wrapping
    # the member in a template type when declaring them.
    struct.addcode(
        """
#ifdef __clang__
# pragma clang diagnostic push
# if __has_warning("-Wdefaulted-function-deleted")
# pragma clang diagnostic ignored "-Wdefaulted-function-deleted"
# endif
#endif
${name}() = default;
#ifdef __clang__
# pragma clang diagnostic pop
#endif
""",
        name=sd.name,
    )

    # If this is an empty struct (no fields), then the default ctor
    # and "create-with-fields" ctors are equivalent.
    if len(sd.fields):
        assert len(sd.fields) == len(sd.packed_field_order)

        # Struct(const field1& _f1, ...)
        # Parameters follow IPDL order; member initializers follow member
        # (packed) order. Fields that must be moved as data take forceMoveType.
        valctor = ConstructorDefn(
            ConstructorDecl(
                sd.name,
                params=[
                    Decl(
                        f.forceMoveType()
                        if _cxxTypeNeedsMoveForData(f.ipdltype)
                        else f.constRefType(),
                        f.argVar().name,
                    )
                    for f in sd.fields_ipdl_order()
                ],
                force_inline=True,
            )
        )
        valctor.memberinits = []
        for f in sd.fields_member_order():
            arg = f.argVar()
            if _cxxTypeNeedsMoveForData(f.ipdltype):
                arg = ExprMove(arg)
            valctor.memberinits.append(ExprMemberInit(f.memberVar(), args=[arg]))

        struct.addstmts([valctor, Whitespace.NL])

        # If a constructor which moves each argument would be different from the
        # `const T&` version, also generate that constructor.
        if not all(
            _cxxTypeNeedsMoveForData(f.ipdltype) or not _cxxTypeCanMove(f.ipdltype)
            for f in sd.fields_ipdl_order()
        ):
            # Struct(field1&& _f1, ...)
            valmovector = ConstructorDefn(
                ConstructorDecl(
                    sd.name,
                    params=[
                        Decl(
                            f.forceMoveType()
                            if _cxxTypeCanMove(f.ipdltype)
                            else f.constRefType(),
                            f.argVar().name,
                        )
                        for f in sd.fields_ipdl_order()
                    ],
                    force_inline=True,
                )
            )

            valmovector.memberinits = []
            for f in sd.fields_member_order():
                arg = f.argVar()
                if _cxxTypeCanMove(f.ipdltype):
                    arg = ExprMove(arg)
                valmovector.memberinits.append(
                    ExprMemberInit(f.memberVar(), args=[arg])
                )

            struct.addstmts([valmovector, Whitespace.NL])

    # The default copy, move, and assignment constructors, and the default
    # destructor, will do the right thing.

    if "Comparable" in sd.attributes:
        # bool operator==(const Struct& _o)
        # Compares getters field by field in IPDL order, returning false at
        # the first mismatch.
        ovar = ExprVar("_o")
        opeqeq = MethodDefn(
            MethodDecl(
                "operator==",
                params=[Decl(constreftype, ovar.name)],
                ret=Type.BOOL,
                const=True,
            )
        )
        for f in sd.fields_ipdl_order():
            ifneq = StmtIf(
                ExprNot(
                    ExprBinary(
                        ExprCall(f.getMethod()), "==", ExprCall(f.getMethod(ovar))
                    )
                )
            )
            ifneq.addifstmt(StmtReturn.FALSE)
            opeqeq.addstmt(ifneq)
        opeqeq.addstmt(StmtReturn.TRUE)
        struct.addstmts([opeqeq, Whitespace.NL])

        # bool operator!=(const Struct& _o)
        opneq = MethodDefn(
            MethodDecl(
                "operator!=",
                params=[Decl(constreftype, ovar.name)],
                ret=Type.BOOL,
                const=True,
            )
        )
        opneq.addstmt(StmtReturn(ExprNot(ExprCall(ExprVar("operator=="), args=[ovar]))))
        struct.addstmts([opneq, Whitespace.NL])

    # field1& f1()
    # const field1& f1() const
    for f in sd.fields_ipdl_order():
        get = MethodDefn(
            MethodDecl(
                f.getMethod().name, params=[], ret=f.refType(), force_inline=True
            )
        )
        get.addstmt(StmtReturn(f.refExpr()))

        # The const overload reuses a deep copy of the mutable getter's
        # declaration with a const return type and const qualifier.
        getconstdecl = deepcopy(get.decl)
        getconstdecl.ret = f.constRefType()
        getconstdecl.const = True
        getconst = MethodDefn(getconstdecl)
        getconst.addstmt(StmtReturn(f.constRefExpr()))

        struct.addstmts([get, getconst, Whitespace.NL])

    # private:
    struct.addstmt(Label.PRIVATE)

    # Static assertions to ensure our assumptions about field layout match
    # what the compiler is actually producing. We define this as a member
    # function, rather than throwing the assertions in the constructor or
    # similar, because we don't want to evaluate the static assertions every
    # time the header file containing the structure is included.
    staticasserts = _fieldStaticAssertions(sd)
    if staticasserts:
        method = MethodDefn(
            MethodDecl("StaticAssertions", params=[], ret=Type.VOID, const=True)
        )
        method.addstmts(staticasserts)
        struct.addstmts([method])

    # members
    struct.addstmts(
        [
            StmtDecl(Decl(_effectiveMemberType(f), f.memberVar().name))
            for f in sd.fields_member_order()
        ]
    )

    return forwarddeclstmts, fulldecltypes, struct
def _effectiveMemberType(f):
    """Return the C++ type of the member variable that backs struct field |f|.

    The field's bare C++ type is wrapped in
    |::mozilla::ipc::IPDLStructMember<...>|. Structs must stay copyable for
    backwards-compatibility reasons, so an |nsTArray<T>| field is stored as
    |CopyableTArray<T>| (a subclass of nsTArray<T>). This substitution is
    not visible in the accessor signatures, which keep using nsTArray<T>.

    NOTE(review): the Type returned by |f.bareType()| is renamed in place;
    this assumes bareType() hands back a fresh object on each call.
    """
    # Array templates that are stored under a different, copyable name.
    copyable_replacements = {"nsTArray": "CopyableTArray"}

    member_type = f.bareType()
    replacement = copyable_replacements.get(member_type.name)
    if replacement is not None:
        member_type.name = replacement
    return Type("::mozilla::ipc::IPDLStructMember", T=[member_type])
# --------------------------------------------------
def _generateCxxUnion(ud):
    """Lower the IPDL union declaration |ud| to a C++ discriminated-union class.

    Returns a tuple |(forwarddeclstmts, fulldecltypes, cls)|:
    - forwarddeclstmts / fulldecltypes: the forward declarations and the
      types needing full declarations, as computed by |_ComputeTypeDeps|
      over every component type of the union;
    - cls: the generated |Class| (declared |final|).
    """
    # This Union class basically consists of a type (enum) and a
    # union for storage. The union can contain POD and non-POD
    # types. Each type needs a copy/move ctor, assignment operators,
    # and dtor.
    #
    # Rather than templating this class and only providing
    # specializations for the types we support, which is slightly
    # "unsafe" in that C++ code can add additional specializations
    # without the IPDL compiler's knowledge, we instead explicitly
    # implement non-templated methods for each supported type.
    #
    # The one complication that arises is that C++, for arcane
    # reasons, does not allow the placement destructor of a
    # builtin type, like int, to be directly invoked. So we need
    # to hack around this by internally typedef'ing all
    # constituent types. Sigh.
    #
    # So, for each type, this "Union" class needs:
    # (private)
    # - entry in the type enum
    # - entry in the storage union
    # - [type]ptr() method to get a type* from the underlying union
    # - same as above to get a const type*
    # - typedef to hack around placement delete limitations
    # (public)
    # - placement delete case for dtor
    # - copy ctor
    # - move ctor
    # - case in generic copy ctor
    # - copy operator= impl
    # - move operator= impl
    # - case in generic operator=
    # - operator [type&]
    # - operator [const type&] const
    # - [type&] get_[type]()
    # - [const type&] get_[type]() const
    #
    cls = Class(ud.name, final=True)
    # const Union&, i.e., Union type with inparam semantics
    inClsType = Type(ud.name, const=True, ref=True)
    refClsType = Type(ud.name, ref=True)
    rvalueRefClsType = Type(ud.name, rvalref=True)
    typetype = Type("Type")
    valuetype = Type("Value")
    mtypevar = ExprVar("mType")
    mvaluevar = ExprVar("mValue")
    maybedtorvar = ExprVar("MaybeDestroy")
    assertsanityvar = ExprVar("AssertSanity")
    tnonevar = ExprVar("T__None")
    tlastvar = ExprVar("T__Last")

    # Build a call expression to AssertSanity(): on |uvar| when given
    # (otherwise on the current object), passing |expectTypeVar| when given
    # so the one-argument overload (exact type-tag check) is selected.
    def callAssertSanity(uvar=None, expectTypeVar=None):
        func = assertsanityvar
        args = []
        if uvar is not None:
            func = ExprSelect(uvar, ".", assertsanityvar.name)
        if expectTypeVar is not None:
            args.append(expectTypeVar)
        return ExprCall(func, args=args)

    # Statement calling this->MaybeDestroy().
    def maybeDestroy():
        return StmtExpr(ExprCall(maybedtorvar))

    # compute all the typedefs and forward decls we need to make
    gettypedeps = _ComputeTypeDeps(ud.decl.type)
    for c in ud.components:
        c.ipdltype.accept(gettypedeps)
    usingTypedefs = gettypedeps.usingTypedefs
    forwarddeclstmts = gettypedeps.forwardDeclStmts
    fulldecltypes = gettypedeps.fullDeclTypes

    # the |Type| enum, used to switch on the discunion's real type.
    # T__None is 0, the first component is 1, and T__Last aliases the last
    # component's enumerator so AssertSanity() can range-check mType.
    cls.addstmt(Label.PUBLIC)
    typeenum = TypeEnum(typetype.name)
    typeenum.addId(tnonevar.name, 0)
    firstid = ud.components[0].enum()
    typeenum.addId(firstid, 1)
    for c in ud.components[1:]:
        typeenum.addId(c.enum())
    typeenum.addId(tlastvar.name, ud.components[-1].enum())
    cls.addstmts([StmtDecl(Decl(typeenum, "")), Whitespace.NL])

    cls.addstmt(Label.PRIVATE)
    cls.addstmts(
        usingTypedefs
        # hacky typedef's that allow placement dtors of builtins
        + [Typedef(c.internalType(), c.typedef()) for c in ud.components]
    )
    cls.addstmt(Whitespace.NL)

    # the C++ union the discunion uses for storage
    valueunion = TypeUnion(valuetype.name)
    for c in ud.components:
        valueunion.addComponent(c.unionType(), c.name)
    cls.addstmts([StmtDecl(Decl(valueunion, "")), Whitespace.NL])

    # for each constituent type T, add private accessors that
    # return a pointer to the Value union storage casted to |T*|
    # and |const T*|
    for c in ud.components:
        getptr = MethodDefn(
            MethodDecl(
                c.getPtrName(), params=[], ret=c.ptrToInternalType(), force_inline=True
            )
        )
        getptr.addstmt(StmtReturn(c.ptrToSelfExpr()))

        getptrconst = MethodDefn(
            MethodDecl(
                c.getConstPtrName(),
                params=[],
                ret=c.constPtrToType(),
                const=True,
                force_inline=True,
            )
        )
        getptrconst.addstmt(StmtReturn(c.constptrToSelfExpr()))

        cls.addstmts([getptr, getptrconst])
    cls.addstmt(Whitespace.NL)

    # add a helper method, MaybeDestroy(), that invokes the placement dtor
    # on the current underlying value if there is one (mType != T__None).
    # It takes no arguments and returns void. It does NOT reset mType;
    # every caller below updates mType itself afterwards.
    maybedtor = MethodDefn(MethodDecl(maybedtorvar.name, ret=Type.VOID))
    # nothing to destroy when the union currently holds no value
    ifnone = StmtIf(ExprBinary(mtypevar, "==", tnonevar))
    ifnone.addifstmt(StmtReturn())
    # need to destroy. switch on underlying type
    dtorswitch = StmtSwitch(mtypevar)
    for c in ud.components:
        dtorswitch.addcase(
            CaseLabel(c.enum()), StmtBlock([StmtExpr(c.callDtor()), StmtBreak()])
        )
    dtorswitch.addcase(
        DefaultLabel(), StmtBlock([_logicError("not reached"), StmtBreak()])
    )
    maybedtor.addstmts([ifnone, dtorswitch])
    cls.addstmts([maybedtor, Whitespace.NL])

    # add helper methods that ensure the discunion has a
    # valid type: AssertSanity() checks T__None <= mType <= T__Last,
    # AssertSanity(aType) additionally checks mType == aType.
    sanity = MethodDefn(
        MethodDecl(assertsanityvar.name, ret=Type.VOID, const=True, force_inline=True)
    )
    sanity.addstmts(
        [
            _abortIfFalse(ExprBinary(tnonevar, "<=", mtypevar), "invalid type tag"),
            _abortIfFalse(ExprBinary(mtypevar, "<=", tlastvar), "invalid type tag"),
        ]
    )
    cls.addstmt(sanity)

    atypevar = ExprVar("aType")
    sanity2 = MethodDefn(
        MethodDecl(
            assertsanityvar.name,
            params=[Decl(typetype, atypevar.name)],
            ret=Type.VOID,
            const=True,
            force_inline=True,
        )
    )
    sanity2.addstmts(
        [
            StmtExpr(ExprCall(assertsanityvar)),
            _abortIfFalse(ExprBinary(mtypevar, "==", atypevar), "unexpected type tag"),
        ]
    )
    cls.addstmts([sanity2, Whitespace.NL])

    # ---- begin public methods -----

    # Union() default ctor: starts out empty (T__None)
    cls.addstmts(
        [
            Label.PUBLIC,
            ConstructorDefn(
                ConstructorDecl(ud.name, force_inline=True),
                memberinits=[ExprMemberInit(mtypevar, [tnonevar])],
            ),
            Whitespace.NL,
        ]
    )

    # Union(const T&) copy & Union(T&&) move ctors. The copy ctor is omitted
    # for component types that must be moved; the move ctor is omitted for
    # component types that cannot be moved.
    othervar = ExprVar("aOther")
    for c in ud.components:
        if not _cxxTypeNeedsMoveForData(c.ipdltype):
            copyctor = ConstructorDefn(
                ConstructorDecl(ud.name, params=[Decl(c.constRefType(), othervar.name)])
            )
            copyctor.addstmts(
                [
                    StmtExpr(c.callCtor(othervar)),
                    StmtExpr(ExprAssn(mtypevar, c.enumvar())),
                ]
            )
            cls.addstmts([copyctor, Whitespace.NL])

        if not _cxxTypeCanMove(c.ipdltype):
            continue
        movector = ConstructorDefn(
            ConstructorDecl(ud.name, params=[Decl(c.forceMoveType(), othervar.name)])
        )
        movector.addstmts(
            [
                StmtExpr(c.callCtor(ExprMove(othervar))),
                StmtExpr(ExprAssn(mtypevar, c.enumvar())),
            ]
        )
        cls.addstmts([movector, Whitespace.NL])

    # If any component is move-only, the whole union is move-only: the
    # Union(const Union&) ctor and operator=(const Union&) are not generated.
    unionNeedsMove = any(_cxxTypeNeedsMoveForData(c.ipdltype) for c in ud.components)

    # Union(const Union&) copy ctor
    if not unionNeedsMove:
        copyctor = ConstructorDefn(
            ConstructorDecl(ud.name, params=[Decl(inClsType, othervar.name)])
        )
        othertype = ud.callType(othervar)
        copyswitch = StmtSwitch(othertype)
        for c in ud.components:
            copyswitch.addcase(
                CaseLabel(c.enum()),
                StmtBlock(
                    [
                        StmtExpr(
                            c.callCtor(
                                ExprCall(
                                    ExprSelect(othervar, ".", c.getConstTypeName())
                                )
                            )
                        ),
                        StmtBreak(),
                    ]
                ),
            )
        copyswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()]))
        copyswitch.addcase(
            DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()])
        )
        copyctor.addstmts(
            [
                StmtExpr(callAssertSanity(uvar=othervar)),
                copyswitch,
                StmtExpr(ExprAssn(mtypevar, othertype)),
            ]
        )
        cls.addstmts([copyctor, Whitespace.NL])

    # Union(Union&&) move ctor. Afterwards aOther is left empty (T__None).
    movector = ConstructorDefn(
        ConstructorDecl(ud.name, params=[Decl(rvalueRefClsType, othervar.name)])
    )
    othertypevar = ExprVar("t")
    moveswitch = StmtSwitch(othertypevar)
    for c in ud.components:
        case = StmtBlock()
        if c.recursive:
            # This is sound as we set othervar.mTypeVar to T__None after the
            # switch. The pointer in the union will be left dangling.
            case.addstmts(
                [
                    # ptr_C() = other.ptr_C()
                    StmtExpr(
                        ExprAssn(
                            c.callGetPtr(),
                            ExprCall(
                                ExprSelect(othervar, ".", ExprVar(c.getPtrName()))
                            ),
                        )
                    )
                ]
            )
        else:
            case.addstmts(
                [
                    # new ... (Move(other.get_C()))
                    StmtExpr(
                        c.callCtor(
                            ExprMove(
                                ExprCall(ExprSelect(othervar, ".", c.getTypeName()))
                            )
                        )
                    ),
                    # other.MaybeDestroy(): destroy the moved-from value
                    StmtExpr(ExprCall(ExprSelect(othervar, ".", maybedtorvar))),
                ]
            )
        case.addstmts([StmtBreak()])
        moveswitch.addcase(CaseLabel(c.enum()), case)
    moveswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()]))
    moveswitch.addcase(
        DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()])
    )
    movector.addstmts(
        [
            StmtExpr(callAssertSanity(uvar=othervar)),
            StmtDecl(Decl(typetype, othertypevar.name), init=ud.callType(othervar)),
            moveswitch,
            StmtExpr(ExprAssn(ExprSelect(othervar, ".", mtypevar), tnonevar)),
            StmtExpr(ExprAssn(mtypevar, othertypevar)),
        ]
    )
    cls.addstmts([movector, Whitespace.NL])

    # ~Union()
    dtor = DestructorDefn(DestructorDecl(ud.name))
    dtor.addstmt(maybeDestroy())
    cls.addstmts([dtor, Whitespace.NL])

    # type()
    typemeth = MethodDefn(
        MethodDecl("type", ret=typetype, const=True, force_inline=True)
    )
    typemeth.addstmt(StmtReturn(mtypevar))
    cls.addstmts([typemeth, Whitespace.NL])

    # Union& operator= methods
    rhsvar = ExprVar("aRhs")
    for c in ud.components:

        # Shared body of operator=(const T&) / operator=(T&&): destroy the
        # old value, construct the new one from |rhs|, update the tag.
        # (Defined and used within the same iteration, so the closure over
        # |c| is safe.)
        def opeqBody(rhs):
            return [
                # might need to placement-delete old value first
                maybeDestroy(),
                StmtExpr(c.callCtor(rhs)),
                StmtExpr(ExprAssn(mtypevar, c.enumvar())),
                StmtReturn(ExprDeref(ExprVar.THIS)),
            ]

        if not _cxxTypeNeedsMoveForData(c.ipdltype):
            # Union& operator=(const T&)
            opeq = MethodDefn(
                MethodDecl(
                    "operator=",
                    params=[Decl(c.constRefType(), rhsvar.name)],
                    ret=refClsType,
                )
            )
            opeq.addstmts(opeqBody(rhsvar))
            cls.addstmts([opeq, Whitespace.NL])

        # Union& operator=(T&&)
        if not _cxxTypeCanMove(c.ipdltype):
            continue

        opeq = MethodDefn(
            MethodDecl(
                "operator=",
                params=[Decl(c.forceMoveType(), rhsvar.name)],
                ret=refClsType,
            )
        )
        opeq.addstmts(opeqBody(ExprMove(rhsvar)))
        cls.addstmts([opeq, Whitespace.NL])

    # Union& operator=(const Union&)
    # NOTE(review): no self-assignment guard is generated. For |u = u| each
    # case runs MaybeDestroy() before reading aRhs.get_C(), i.e. it reads a
    # destroyed value. Confirm callers never self-assign.
    if not unionNeedsMove:
        opeq = MethodDefn(
            MethodDecl(
                "operator=", params=[Decl(inClsType, rhsvar.name)], ret=refClsType
            )
        )
        rhstypevar = ExprVar("t")
        opeqswitch = StmtSwitch(rhstypevar)
        for c in ud.components:
            case = StmtBlock()
            case.addstmts(
                [
                    maybeDestroy(),
                    StmtExpr(
                        c.callCtor(
                            ExprCall(ExprSelect(rhsvar, ".", c.getConstTypeName()))
                        )
                    ),
                    StmtBreak(),
                ]
            )
            opeqswitch.addcase(CaseLabel(c.enum()), case)
        opeqswitch.addcase(
            CaseLabel(tnonevar.name),
            StmtBlock([maybeDestroy(), StmtBreak()]),
        )
        opeqswitch.addcase(
            DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()])
        )
        opeq.addstmts(
            [
                StmtExpr(callAssertSanity(uvar=rhsvar)),
                StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)),
                opeqswitch,
                StmtExpr(ExprAssn(mtypevar, rhstypevar)),
                StmtReturn(ExprDeref(ExprVar.THIS)),
            ]
        )
        cls.addstmts([opeq, Whitespace.NL])

    # Union& operator=(Union&&). Afterwards aRhs is left empty (T__None).
    # NOTE(review): as with the copy assignment, there is no self-assignment
    # guard; MaybeDestroy() runs before the value is taken from aRhs.
    opeq = MethodDefn(
        MethodDecl(
            "operator=", params=[Decl(rvalueRefClsType, rhsvar.name)], ret=refClsType
        )
    )
    rhstypevar = ExprVar("t")
    opeqswitch = StmtSwitch(rhstypevar)
    for c in ud.components:
        case = StmtBlock()
        if c.recursive:
            # Steal the pointer; aRhs.mType is set to T__None after the
            # switch so aRhs will not destroy it.
            case.addstmts(
                [
                    maybeDestroy(),
                    StmtExpr(
                        ExprAssn(
                            c.callGetPtr(),
                            ExprCall(ExprSelect(rhsvar, ".", ExprVar(c.getPtrName()))),
                        )
                    ),
                ]
            )
        else:
            case.addstmts(
                [
                    maybeDestroy(),
                    StmtExpr(
                        c.callCtor(
                            ExprMove(ExprCall(ExprSelect(rhsvar, ".", c.getTypeName())))
                        )
                    ),
                    # other.MaybeDestroy()
                    StmtExpr(ExprCall(ExprSelect(rhsvar, ".", maybedtorvar))),
                ]
            )
        case.addstmts([StmtBreak()])
        opeqswitch.addcase(CaseLabel(c.enum()), case)
    opeqswitch.addcase(
        CaseLabel(tnonevar.name),
        StmtBlock([maybeDestroy(), StmtBreak()]),
    )
    opeqswitch.addcase(
        DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()])
    )
    opeq.addstmts(
        [
            StmtExpr(callAssertSanity(uvar=rhsvar)),
            StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)),
            opeqswitch,
            StmtExpr(ExprAssn(ExprSelect(rhsvar, ".", mtypevar), tnonevar)),
            StmtExpr(ExprAssn(mtypevar, rhstypevar)),
            StmtReturn(ExprDeref(ExprVar.THIS)),
        ]
    )
    cls.addstmts([opeq, Whitespace.NL])

    if "Comparable" in ud.attributes:
        # bool operator==(const T&): compares the active value via get_T(),
        # which asserts that T is the active type.
        for c in ud.components:
            opeqeq = MethodDefn(
                MethodDecl(
                    "operator==",
                    params=[Decl(c.constRefType(), rhsvar.name)],
                    ret=Type.BOOL,
                    const=True,
                )
            )
            opeqeq.addstmt(
                StmtReturn(ExprBinary(ExprCall(ExprVar(c.getTypeName())), "==", rhsvar))
            )
            cls.addstmts([opeqeq, Whitespace.NL])

        # bool operator==(const Union&): different type tags compare unequal;
        # otherwise the active values are compared.
        opeqeq = MethodDefn(
            MethodDecl(
                "operator==",
                params=[Decl(inClsType, rhsvar.name)],
                ret=Type.BOOL,
                const=True,
            )
        )
        iftypesmismatch = StmtIf(ExprBinary(ud.callType(), "!=", ud.callType(rhsvar)))
        iftypesmismatch.addifstmt(StmtReturn.FALSE)
        opeqeq.addstmts([iftypesmismatch, Whitespace.NL])

        opeqeqswitch = StmtSwitch(ud.callType())
        for c in ud.components:
            case = StmtBlock()
            case.addstmt(
                StmtReturn(
                    ExprBinary(
                        ExprCall(ExprVar(c.getTypeName())),
                        "==",
                        ExprCall(ExprSelect(rhsvar, ".", c.getTypeName())),
                    )
                )
            )
            opeqeqswitch.addcase(CaseLabel(c.enum()), case)
        opeqeqswitch.addcase(
            DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn.FALSE])
        )
        opeqeq.addstmt(opeqeqswitch)

        cls.addstmts([opeqeq, Whitespace.NL])

    # accessors for each type: operator T&, operator const T&,
    # T& get(), const T& get()
    # The get_*() accessors assert that the requested type is the active one.
    for c in ud.components:
        getValueVar = ExprVar(c.getTypeName())
        getConstValueVar = ExprVar(c.getConstTypeName())

        getvalue = MethodDefn(
            MethodDecl(getValueVar.name, ret=c.refType(), force_inline=True)
        )
        getvalue.addstmts(
            [
                StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())),
                StmtReturn(ExprDeref(c.callGetPtr())),
            ]
        )

        getconstvalue = MethodDefn(
            MethodDecl(
                getConstValueVar.name,
                ret=c.constRefType(),
                const=True,
                force_inline=True,
            )
        )
        getconstvalue.addstmts(
            [
                StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())),
                StmtReturn(c.getConstValue()),
            ]
        )

        cls.addstmts([getvalue, getconstvalue])

        optype = MethodDefn(MethodDecl("", typeop=c.refType(), force_inline=True))
        optype.addstmt(StmtReturn(ExprCall(getValueVar)))
        opconsttype = MethodDefn(
            MethodDecl("", const=True, typeop=c.constRefType(), force_inline=True)
        )
        opconsttype.addstmt(StmtReturn(ExprCall(getConstValueVar)))

        cls.addstmts([optype, opconsttype, Whitespace.NL])

    # private vars
    cls.addstmts(
        [
            Label.PRIVATE,
            StmtDecl(Decl(valuetype, mvaluevar.name)),
            StmtDecl(Decl(typetype, mtypevar.name)),
        ]
    )

    return forwarddeclstmts, fulldecltypes, cls
# -----------------------------------------------------------------------------
class _FindFriends(ipdl.ast.Visitor):
    """Collect the protocols that must be |friend|s of a given protocol.

    A protocol P is reported as a friend of |mytype| when one of P's message
    declarations carries a |mytype| actor in its in- or out-parameters.
    """

    def __init__(self):
        self.mytype = None  # ProtocolType whose friends are being collected
        self.vtype = None  # ProtocolType currently being visited
        self.friends = set()  # set<ProtocolType>
        self.walked = set()  # set<ProtocolType> already walked by this search

    def findFriends(self, ptype):
        """Return the set of protocols, in every protocol tree containing
        |ptype|, whose messages reference |ptype| actors."""
        self.mytype = ptype
        self.walked = set()
        for toplvl in ptype.toplevels():
            self.walkDownTheProtocolTree(toplvl)
        return self.friends

    # TODO could make this into a _iterProtocolTreeHelper ...
    def walkDownTheProtocolTree(self, ptype):
        """Visit |ptype| and, recursively, every protocol it manages.

        Each protocol is walked at most once per search. A protocol reachable
        through several managers (or several toplevels) would otherwise have
        its whole subtree re-walked once per path, without changing the
        result.
        """
        if ptype in self.walked:
            return
        self.walked.add(ptype)
        if ptype != self.mytype:
            # don't want to |friend| ourself!
            self.visit(ptype)
        for mtype in ptype.manages:
            if mtype is not ptype:
                self.walkDownTheProtocolTree(mtype)

    def visit(self, ptype):
        # |vtype| is the type currently being visited. Restore the previous
        # value even if the visit raises, so the visitor state stays sane.
        savedptype = self.vtype
        self.vtype = ptype
        try:
            ptype._ast.accept(self)
        finally:
            self.vtype = savedptype

    def visitMessageDecl(self, md):
        for it in self.iterActorParams(md):
            if it.protocol == self.mytype:
                self.friends.add(self.vtype)

    def iterActorParams(self, md):
        """Yield every actor type found in |md|'s in-params, then out-params."""
        for param in itertools.chain(md.inParams, md.outParams):
            yield from ipdl.type.iteractortypes(param.type)
class _GenerateProtocolActorCode(ipdl.ast.Visitor):
def __init__(self, myside):
    """Create a code generator for the |myside| ("parent" or "child") actor."""
    self.side = myside  # "parent" or "child"
    self.prettyside = myside.title()

    # Filled in later by lower() and the visit* methods.
    self.clsname = self.protocol = None
    self.hdrfile = self.cppfile = None
    self.ns = self.cls = None

    # Accumulated while visiting includes, usings, structs and unions.
    self.protocolCxxIncludes = []
    self.actorForwardDecls = []
    self.usingDecls = []
    self.externalIncludes = set()
    self.nonForwardDeclaredHeaders = set()

    # Typedefs every generated actor class starts with; visitors add more.
    self.typedefSet = {
        Typedef(Type("mozilla::ipc::ActorHandle"), "ActorHandle"),
        Typedef(Type("base::ProcessId"), "ProcessId"),
        Typedef(Type("mozilla::ipc::ProtocolId"), "ProtocolId"),
        Typedef(Type("mozilla::ipc::Endpoint"), "Endpoint", ["FooSide"]),
        Typedef(
            Type("mozilla::ipc::ManagedEndpoint"),
            "ManagedEndpoint",
            ["FooSide"],
        ),
        Typedef(Type("mozilla::UniquePtr"), "UniquePtr", ["T"]),
        Typedef(Type("mozilla::ipc::ResponseRejectReason"), "ResponseRejectReason"),
    }
def lower(self, tu, clsname, cxxHeaderFile, cxxFile):
    """Generate the |clsname| actor for |tu| into the given header and .cpp
    files by visiting the translation unit."""
    self.clsname, self.hdrfile, self.cppfile = clsname, cxxHeaderFile, cxxFile
    tu.accept(self)
def standardTypedefs(self):
    """Return the fixed list of typedefs added to every generated actor class."""
    aliases = (
        ("mozilla::ipc::IProtocol", "IProtocol"),
        ("IPC::Message", "Message"),
        ("base::ProcessHandle", "ProcessHandle"),
        ("mozilla::ipc::MessageChannel", "MessageChannel"),
        ("mozilla::ipc::SharedMemory", "SharedMemory"),
    )
    return [Typedef(Type(fullname), shortname) for fullname, shortname in aliases]
def visitTranslationUnit(self, tu):
    """Emit the actor header (self.hdrfile) and implementation (self.cppfile)
    for the protocol declared in |tu|.

    Statement order matters: things are appended to the output files in the
    order they appear below.
    """
    self.protocol = tu.protocol

    # |hf| and |self.hdrfile| refer to the same File object.
    hf = self.hdrfile
    cf = self.cppfile

    # make the C++ header
    hf.addthings(
        [_DISCLAIMER]
        + _includeGuardStart(hf)
        + [
            Whitespace.NL,
            CppDirective("include", '"' + _protocolHeaderName(tu.protocol) + '.h"'),
        ]
    )

    # Visit includes, usings and structs/unions before the protocol itself.
    # These visitors fill typedefSet, actorForwardDecls, usingDecls,
    # externalIncludes, nonForwardDeclaredHeaders and protocolCxxIncludes,
    # which are consumed below (typedefSet is emitted by visitProtocol).
    for inc in tu.includes:
        inc.accept(self)
    for inc in tu.cxxIncludes:
        inc.accept(self)

    for using in tu.builtinUsing:
        using.accept(self)
    for using in tu.using:
        using.accept(self)

    for su in tu.structsAndUnions:
        su.accept(self)

    # this generates the actor's full impl in self.cls
    tu.protocol.accept(self)

    clsdecl, clsdefn = _splitClassDeclDefn(self.cls)

    # XXX damn C++ ... return types in the method defn aren't in
    # class scope
    # (so out-of-line definitions returning |Result| get it qualified as
    # |ClassName::Result|)
    for stmt in clsdefn.stmts:
        if isinstance(stmt, MethodDefn):
            if stmt.decl.ret and stmt.decl.ret.name == "Result":
                stmt.decl.ret.name = clsdecl.name + "::" + stmt.decl.ret.name

    # Turn a set of header paths into #include directives, sorted so the
    # output does not depend on set iteration order.
    def setToIncludes(s):
        return [CppDirective("include", '"%s"' % i) for i in sorted(iter(s))]

    # Wrap output for protocol |p| in its namespaces. Returns the innermost
    # Namespace to add statements to (after adding the outermost one to
    # |file|), or |file| itself when |p| has no namespaces.
    def makeNamespace(p, file):
        if 0 == len(p.namespaces):
            return file
        ns = Namespace(p.namespaces[-1].name)
        outerns = _putInNamespaces(ns, p.namespaces[:-1])
        file.addthing(outerns)
        return ns

    if len(self.nonForwardDeclaredHeaders) != 0:
        self.hdrfile.addthings(
            [
                Whitespace("// Headers for things that cannot be forward declared"),
                Whitespace.NL,
            ]
            + setToIncludes(self.nonForwardDeclaredHeaders)
            + [Whitespace.NL]
        )
    self.hdrfile.addthings(self.actorForwardDecls)
    self.hdrfile.addthings(self.usingDecls)

    hdrns = makeNamespace(self.protocol, self.hdrfile)
    hdrns.addstmts(
        [Whitespace.NL, Whitespace.NL, clsdecl, Whitespace.NL, Whitespace.NL]
    )

    actortype = ActorType(tu.protocol.decl.type)
    traitsdecl, traitsdefn = _ParamTraits.actorPickling(actortype, self.side)

    # The ParamTraits declaration goes after (outside) the namespaces, just
    # before the closing include guard.
    self.hdrfile.addthings([traitsdecl, Whitespace.NL] + _includeGuardEnd(hf))

    # If the implementation type is not overridden, add an implicit import
    # for the default implementation header file. Explicit implementation
    # types will specify their headers manually with `include`.
    # The path is "<ns1>/<ns2>/.../" + protocol name without its leading "P"
    # + "Parent"/"Child" + ".h".
    if self.protocol.implAttribute(self.side) is None:
        assert self.protocol.name.startswith("P")
        self.externalIncludes.add(
            "".join(n.name + "/" for n in self.protocol.namespaces)
            + self.protocol.name[1:]
            + self.side.capitalize()
            + ".h"
        )

    # make the .cpp file
    cf.addthings(
        [
            _DISCLAIMER,
            Whitespace.NL,
            CppDirective(
                "include",
                '"' + _protocolHeaderName(self.protocol, self.side) + '.h"',
            ),
        ]
        + setToIncludes(self.externalIncludes)
    )

    cf.addthings(
        (
            [Whitespace.NL]
            + [
                CppDirective("include", '"%s.h"' % (inc))
                for inc in self.protocolCxxIncludes
            ]
            + [Whitespace.NL]
            + [
                CppDirective("include", '"%s"' % filename)
                for filename in ipdl.builtin.CppIncludes
            ]
            + [Whitespace.NL]
        )
    )

    cppns = makeNamespace(self.protocol, cf)
    cppns.addstmts(
        [Whitespace.NL, Whitespace.NL, clsdefn, Whitespace.NL, Whitespace.NL]
    )

    # The ParamTraits definition goes after (outside) the namespaces.
    cf.addthing(traitsdefn)
def visitUsingStmt(self, using):
    """Record what a |using| statement needs in the generated actor files.

    - A typedef for the used type when it has a fully-qualified name.
    - When the statement names a header, either a forward declaration plus
      a .cpp include (forward-declarable types), or a header include in the
      actor header (types that cannot be forward declared).
    """
    decl = using.decl
    if decl.fullname is not None:
        self.typedefSet.add(Typedef(Type(decl.fullname), decl.shortname))

    header = using.header
    if header is None:
        return

    if not using.canBeForwardDeclared():
        self.nonForwardDeclaredHeaders.add(header)
        return

    spec = using.type
    forwardDecl = _makeForwardDeclForQClass(
        spec.baseid,
        spec.quals,
        cls=using.isClass(),
        struct=using.isStruct(),
    )
    self.usingDecls.extend([forwardDecl, Whitespace.NL])
    self.externalIncludes.add(header)
def visitCxxInclude(self, inc):
    """Queue the header named by a C++ |include| for the generated .cpp file."""
    header = inc.file
    self.externalIncludes.add(header)
def visitInclude(self, inc):
    """Handle an IPDL |include| of either a header or another protocol.

    - Header includes: the header's C++ includes, usings and
      structs/unions are visited as if they were declared here.
    - Protocol includes (other than this protocol): add forward declarations
      for both of that protocol's actor sides, its actor header include,
      and typedefs for its actor classes.
    """
    included = inc.tu
    if included.filetype == "header":
        # Including a header will declare any globals defined by "using"
        # statements into our scope. To serialize these, we also may need
        # cxx include statements, so visit them as well.
        for nodes in (included.cxxIncludes, included.using, included.structsAndUnions):
            for node in nodes:
                node.accept(self)
        return

    # Includes for protocols only include types explicitly exported by
    # those protocols.
    ip = included.protocol
    if ip == self.protocol:
        return

    side = self.side
    otherSide = _otherSide(side)
    iptype = ip.decl.type
    self.actorForwardDecls.extend(
        [
            _makeForwardDeclForActor(iptype, side),
            _makeForwardDeclForActor(iptype, otherSide),
            Whitespace.NL,
        ]
    )
    self.protocolCxxIncludes.append(_protocolHeaderName(ip, side))

    if ip.decl.fullname is None:
        return
    # One typedef per actor side: ours first, then the other side.
    for actorSide in (side, otherSide):
        suffix = actorSide.title()
        self.typedefSet.add(
            Typedef(
                Type(_actorName(ip.decl.fullname, suffix)),
                _actorName(ip.decl.shortname, suffix),
            )
        )
def visitStructDecl(self, sd):
    """Add a typedef from struct |sd|'s short name to its fully-qualified C++
    class name; declarations without a full name are ignored."""
    if sd.decl.fullname is None:
        return
    qualified = Type(sd.fqClassName())
    self.typedefSet.add(Typedef(qualified, sd.name))
def visitUnionDecl(self, ud):
    """Add a typedef from union |ud|'s short name to its fully-qualified C++
    class name; declarations without a full name are ignored."""
    if ud.decl.fullname is None:
        return
    qualified = Type(ud.fqClassName())
    self.typedefSet.add(Typedef(qualified, ud.name))
def visitProtocol(self, p):
self.hdrfile.addcode(
"""
#ifdef DEBUG
#include "prenv.h"
#endif // DEBUG
#include "mozilla/Tainting.h"
#include "mozilla/ipc/MessageChannel.h"
#include "mozilla/ipc/ProtocolUtils.h"
"""
)
self.protocol = p
ptype = p.decl.type
toplevel = p.decl.type.toplevel()
hasAsyncReturns = False
for md in p.messageDecls:
if md.hasAsyncReturns():
hasAsyncReturns = True
break
if ptype.isToplevel():
inherits = [Inherit(Type("mozilla::ipc::IToplevelProtocol"))]
elif ptype.isRefcounted():
inherits = [Inherit(Type("mozilla::ipc::IRefCountedProtocol"))]
else:
inherits = [Inherit(Type("mozilla::ipc::IProtocol"))]
if ptype.isToplevel() and self.side == "parent":
self.hdrfile.addthings(
[_makeForwardDeclForQClass("nsIFile", []), Whitespace.NL]
)
self.cls = Class(self.clsname, inherits=inherits, abstract=True)
self.cls.addstmt(Label.PRIVATE)
friends = _FindFriends().findFriends(ptype)
if ptype.isManaged():
friends.update(ptype.managers)
# |friend| managed actors so that they can call our Dealloc*()
friends.update(ptype.manages)
# don't friend ourself if we're a self-managed protocol
friends.discard(ptype)
for friend in sorted(friends, key=lambda f: f.fullname()):
self.actorForwardDecls.extend(
[_makeForwardDeclForActor(friend, self.prettyside), Whitespace.NL]
)
self.cls.addstmt(
FriendClassDecl(_actorName(friend.fullname(), self.prettyside))
)
self.cls.addstmt(Label.PROTECTED)
for typedef in sorted(self.typedefSet):
self.cls.addstmt(typedef)
self.cls.addstmt(Whitespace.NL)
if hasAsyncReturns:
self.cls.addstmt(Label.PUBLIC)
for md in p.messageDecls:
if self.sendsMessage(md) and md.hasAsyncReturns():
self.cls.addstmt(
Typedef(_makePromise(md.returns, self.side), md.promiseName())
)
if self.receivesMessage(md) and md.hasAsyncReturns():
self.cls.addstmt(
Typedef(_makeResolver(md.returns, self.side), md.resolverName())
)
self.cls.addstmt(Whitespace.NL)
self.cls.addstmt(Label.PROTECTED)
# interface methods that the concrete subclass has to impl
for md in p.messageDecls:
isctor, isdtor = md.decl.type.isCtor(), md.decl.type.isDtor()
if self.receivesMessage(md):
# generate Recv/Answer* interface
implicit = not isdtor
returnsems = "resolver" if md.decl.type.isAsync() else "out"
recvDecl = MethodDecl(
md.recvMethod(),
params=md.makeCxxParams(
paramsems="move",
returnsems=returnsems,
side=self.side,
implicit=implicit,
direction="recv",
),
ret=Type("mozilla::ipc::IPCResult"),
methodspec=MethodSpec.VIRTUAL,
)
# These method implementations cause problems when trying to
# override them with different types in a direct call class.
#
# For the `isdtor` case there's a simple solution: it doesn't
# make much sense to specify arguments and then completely
# ignore them, and the no-arg case isn't a problem for
# overriding.
if isctor or (isdtor and not md.inParams):
defaultRecv = MethodDefn(recvDecl)
defaultRecv.addcode("return IPC_OK();\n")
self.cls.addstmt(defaultRecv)
elif self.protocol.implAttribute(self.side) == "virtual":
# If we're using virtual calls, we need the methods to be
# declared on the base class.
recvDecl.methodspec = MethodSpec.PURE
self.cls.addstmt(StmtDecl(recvDecl))
# If we're using virtual calls, we need the methods to be declared on
# the base class.
if self.protocol.implAttribute(self.side) == "virtual":
for md in p.messageDecls:
managed = md.decl.type.constructedType()
if not ptype.isManagerOf(managed) or md.decl.type.isDtor():
continue
# add the Alloc interface for managed actors
actortype = md.actorDecl().bareType(self.side)
if managed.isRefcounted():
if not self.receivesMessage(md):
continue
actortype.ptr = False
actortype = _alreadyaddrefed(actortype)
self.cls.addstmt(
StmtDecl(
MethodDecl(
_allocMethod(managed, self.side),
params=md.makeCxxParams(
side=self.side, implicit=False, direction="recv"
),
ret=actortype,
methodspec=MethodSpec.PURE,
)
)
)
# add the Dealloc interface for all managed non-refcounted actors,
# even without ctors. This is useful for protocols which use
# ManagedEndpoint for construction.
for managed in ptype.manages:
if managed.isRefcounted():
continue
self.cls.addstmt(
StmtDecl(
MethodDecl(
_deallocMethod(managed, self.side),
params=[
Decl(p.managedCxxType(managed, self.side), "aActor")
],
ret=Type.BOOL,
methodspec=MethodSpec.PURE,
)
)
)
if ptype.isToplevel():
# void ProcessingError(code); default to no-op
processingerror = MethodDefn(
MethodDecl(
p.processingErrorVar().name,
params=[
Param(_Result.Type(), "aCode"),
Param(Type("char", const=True, ptr=True), "aReason"),
],
methodspec=MethodSpec.OVERRIDE,
)
)
# bool ShouldContinueFromReplyTimeout(); default to |true|
shouldcontinue = MethodDefn(
MethodDecl(
p.shouldContinueFromTimeoutVar().name,
ret=Type.BOOL,
methodspec=MethodSpec.OVERRIDE,
)
)
shouldcontinue.addcode("return true;\n")
self.cls.addstmts(
[
processingerror,
shouldcontinue,
Whitespace.NL,
]
)
self.cls.addstmts(([Label.PUBLIC] + self.standardTypedefs() + [Whitespace.NL]))
self.cls.addstmt(Label.PUBLIC)
# Actor()
ctor = ConstructorDefn(ConstructorDecl(self.clsname))
side = ExprVar("mozilla::ipc::" + self.side.title() + "Side")
if ptype.isToplevel():
name = ExprLiteral.String(_actorName(p.name, self.side))
ctor.memberinits = [
ExprMemberInit(
ExprVar("mozilla::ipc::IToplevelProtocol"),
[name, _protocolId(ptype), side],
)
]
else:
baseCtor = (
ExprVar("mozilla::ipc::IRefCountedProtocol")
if ptype.isRefcounted()
else ExprVar("mozilla::ipc::IProtocol")
)
ctor.memberinits = [ExprMemberInit(baseCtor, [_protocolId(ptype), side])]
ctor.addcode("MOZ_COUNT_CTOR(${clsname});\n", clsname=self.clsname)
self.cls.addstmts([ctor, Whitespace.NL])
# ~Actor()
dtor = DestructorDefn(
DestructorDecl(self.clsname, methodspec=MethodSpec.VIRTUAL)
)
dtor.addcode("MOZ_COUNT_DTOR(${clsname});\n", clsname=self.clsname)
self.cls.addstmts([dtor, Whitespace.NL])
# ActorAlloc() and ActorDealloc()
actoralloc = MethodDefn(MethodDecl("ActorAlloc", methodspec=MethodSpec.FINAL))
actordealloc = MethodDefn(
MethodDecl("ActorDealloc", methodspec=MethodSpec.FINAL)
)
# Assert process type in ActorAlloc
procattr = p.procAttribute(self.side)
if procattr not in ("any", None):
if procattr == "anychild":
procattr_assertion = "!XRE_IsParentProcess()"
elif procattr == "anydom":
procattr_assertion = "XRE_IsParentProcess() || XRE_IsContentProcess()"
elif procattr == "compositor":
procattr_assertion = "XRE_IsParentProcess() || XRE_IsGPUProcess()"
else:
procattr_assertion = "XRE_Is%sProcess()" % procattr
actoralloc.addcode(
"MOZ_RELEASE_ASSERT(${assertion}, ${message});\n",
assertion=procattr_assertion,
message=ExprLiteral.String("Invalid process for `%s'" % self.clsname),
)
if ptype.isRefcounted():
# Perform AddRef/Release in ActorAlloc/ActorDealloc if refcounted.
actoralloc.addcode("AddRef();\n")
actordealloc.addcode("Release();\n")
elif not ptype.isToplevel():
# If we're a managed actor with [ManualDealloc], use DeallocManagee
# to invoke the relevant Dealloc method.
actordealloc.addcode(
"""
if (Manager()) {
Manager()->DeallocManagee(${protocolId}, this);
}
""",
protocolId=_protocolId(ptype),
)
self.cls.addstmts([Label.PROTECTED, actoralloc, actordealloc])
self.cls.addstmt(Label.PUBLIC)
if ptype.hasOtherPid():
otherpidmeth = MethodDefn(
MethodDecl("OtherPid", ret=Type("::base::ProcessId"), const=True)
)
otherpidmeth.addcode(
"""
::base::ProcessId pid =
::mozilla::ipc::IProtocol::ToplevelProtocol()->OtherPidMaybeInvalid();
MOZ_RELEASE_ASSERT(pid != ::base::kInvalidProcessId);
return pid;
"""
)
self.cls.addstmts([otherpidmeth, Whitespace.NL])
if not ptype.isToplevel():
if 1 == len(p.managers):
# manager() const
managertype = p.managerActorType(self.side, ptr=True)
managermeth = MethodDefn(
MethodDecl("Manager", ret=managertype, const=True)
)
managermeth.addcode(
"""
return static_cast<${type}>(IProtocol::Manager());
""",
type=managertype,
)
self.cls.addstmts([managermeth, Whitespace.NL])
def actorFromIter(itervar):
return ExprCode("${iter}.Get()->GetKey()", iter=itervar)
def forLoopOverHashtable(hashtable, itervar, const=False):
itermeth = "ConstIter" if const else "Iter"
return StmtFor(
init=ExprCode(
"auto ${itervar} = ${hashtable}.${itermeth}()",
itervar=itervar,
hashtable=hashtable,
itermeth=itermeth,
),
cond=ExprCode("!${itervar}.Done()", itervar=itervar),
update=ExprCode("${itervar}.Next()", itervar=itervar),
)
# Managed[T](Array& inout) const
# const Array<T>& Managed() const
for managed in ptype.manages:
container = p.managedVar(managed, self.side)
meth = MethodDefn(
MethodDecl(
p.managedMethod(managed, self.side).name,
params=[
Decl(
_cxxArrayType(
p.managedCxxType(managed, self.side), ref=True
),
"aArr",
)
],
const=True,
)
)
meth.addcode("${container}.ToArray(aArr);\n", container=container)
refmeth = MethodDefn(
MethodDecl(
p.managedMethod(managed, self.side).name,
params=[],
ret=p.managedVarType(managed, self.side, const=True, ref=True),
const=True,
)
)
refmeth.addcode("return ${container};\n", container=container)
self.cls.addstmts([meth, refmeth, Whitespace.NL])
# AllManagedActorsCount() const
managedcount = MethodDefn(
MethodDecl(
"AllManagedActorsCount",
ret=Type.UINT32,
methodspec=MethodSpec.OVERRIDE,
const=True,
)
)
# Count the number of managed actors.
managedcount.addcode(
"""
uint32_t total = 0;
"""
)
for managed in ptype.manages:
managedcount.addcode(
"""
total += ${container}.Count();
""",
container=p.managedVar(managed, self.side),
)
managedcount.addcode(
"""
return total;
"""
)
self.cls.addstmts([managedcount, Whitespace.NL])
# OpenPEndpoint(...)/BindPEndpoint(...)
for managed in ptype.manages:
self.genManagedEndpoint(managed)
# OnMessageReceived()
# save these away for use in message handler case stmts
msgvar = ExprVar("msg__")
self.msgvar = msgvar
replyvar = ExprVar("reply__")
self.replyvar = replyvar
var = ExprVar("v__")
self.var = var
# for ctor recv cases, we can't read the actor ID into a PFoo*
# because it doesn't exist on this side yet. Use a "special"
# actor handle instead
handlevar = ExprVar("handle__")
self.handlevar = handlevar
msgtype = ExprCode("msg__.type()")
self.asyncSwitch = StmtSwitch(msgtype)
self.syncSwitch = None
if toplevel.isSync():
self.syncSwitch = StmtSwitch(msgtype)
# Add a handler for the MANAGED_ENDPOINT_BOUND and
# MANAGED_ENDPOINT_DROPPED message types for managed actors.
if not ptype.isToplevel():
clearawaitingmanagedendpointbind = """
if (!mAwaitingManagedEndpointBind) {
NS_WARNING("Unexpected managed endpoint lifecycle message after actor bound!");
return MsgNotAllowed;
}
mAwaitingManagedEndpointBind = false;
"""
self.asyncSwitch.addcase(
CaseLabel("MANAGED_ENDPOINT_BOUND_MESSAGE_TYPE"),
StmtBlock(
[
StmtCode(clearawaitingmanagedendpointbind),
StmtReturn(_Result.Processed),
]
),
)
self.asyncSwitch.addcase(
CaseLabel("MANAGED_ENDPOINT_DROPPED_MESSAGE_TYPE"),
StmtBlock(
[
StmtCode(clearawaitingmanagedendpointbind),
*self.destroyActor(
None,
ExprVar.THIS,
why=_DestroyReason.ManagedEndpointDropped,
),
StmtReturn(_Result.Processed),
]
),
)
# implement Send*() methods and add dispatcher cases to
# message switch()es
for md in p.messageDecls:
self.visitMessageDecl(md)
# add default cases
default = StmtCode(
"""
return MsgNotKnown;
"""
)
self.asyncSwitch.addcase(DefaultLabel(), default)
if toplevel.isSync():
self.syncSwitch.addcase(DefaultLabel(), default)
self.cls.addstmts(self.implementManagerIface())
def makeHandlerMethod(name, switch, hasReply, dispatches=False):
params = [Decl(Type("Message", const=True, ref=True), msgvar.name)]
if hasReply:
params.append(Decl(Type("UniquePtr<Message>", ref=True), replyvar.name))
method = MethodDefn(
MethodDecl(
name,
methodspec=MethodSpec.OVERRIDE,
params=params,
ret=_Result.Type(),
)
)
if not switch:
method.addcode(
"""
MOZ_ASSERT_UNREACHABLE("message protocol not supported");
return MsgNotKnown;
"""
)
return method
if dispatches:
if hasReply:
ondeadactor = [StmtReturn(_Result.RouteError)]
else:
ondeadactor = [
self.logMessage(
None, ExprAddrOf(msgvar), "Ignored message for dead actor"
),
StmtReturn(_Result.Processed),
]
method.addcode(
"""
int32_t route__ = ${msgvar}.routing_id();
if (MSG_ROUTING_CONTROL != route__) {
IProtocol* routed__ = Lookup(route__);
if (!routed__ || !routed__->GetLifecycleProxy()) {
$*{ondeadactor}
}
RefPtr<mozilla::ipc::ActorLifecycleProxy> proxy__ =
routed__->GetLifecycleProxy();
return proxy__->Get()->${name}($,{args});
}
""",
msgvar=msgvar,
ondeadactor=ondeadactor,
name=name,
args=[p.name for p in params],
)
# bug 509581: don't generate the switch stmt if there
# is only the default case; MSVC doesn't like that
if switch.nr_cases > 1:
method.addstmt(switch)
else:
method.addstmt(StmtReturn(_Result.NotKnown))
return method
dispatches = ptype.isToplevel() and ptype.isManager()
self.cls.addstmts(
[
makeHandlerMethod(
"OnMessageReceived",
self.asyncSwitch,
hasReply=False,
dispatches=dispatches,
),
Whitespace.NL,
]
)
self.cls.addstmts(
[
makeHandlerMethod(
"OnMessageReceived",
self.syncSwitch,
hasReply=True,
dispatches=dispatches,
),
Whitespace.NL,
]
)
# DoomSubtree()
doomsubtree = MethodDefn(
MethodDecl("DoomSubtree", methodspec=MethodSpec.OVERRIDE)
)
for managed in ptype.manages:
doomsubtree.addcode(
"""
for (auto* key : ${container}) {
key->DoomSubtree();
}
""",
container=p.managedVar(managed, self.side),
)
doomsubtree.addcode("SetDoomed();\n")
self.cls.addstmts([doomsubtree, Whitespace.NL])
# IProtocol* PeekManagedActor() override
peekmanagedactor = MethodDefn(
MethodDecl(
"PeekManagedActor",
methodspec=MethodSpec.OVERRIDE,
ret=Type("IProtocol", ptr=True),
)
)
for managed in ptype.manages:
peekmanagedactor.addcode(
"""
if (IProtocol* actor = ${container}.Peek()) {
return actor;
}
""",
container=p.managedVar(managed, self.side),
)
peekmanagedactor.addcode("return nullptr;\n")
self.cls.addstmts([peekmanagedactor, Whitespace.NL])
# private methods
self.cls.addstmt(Label.PRIVATE)
if not ptype.isToplevel():
self.cls.addstmts(
[
StmtDecl(
Decl(Type.BOOL, "mAwaitingManagedEndpointBind"),
init=ExprLiteral.FALSE,
),
Whitespace.NL,
]
)
for managed in ptype.manages:
self.cls.addstmts(
[
StmtDecl(
Decl(
p.managedVarType(managed, self.side),
p.managedVar(managed, self.side).name,
)
)
]
)
def genManagedEndpoint(self, managed):
    """Add Open<Managed>Endpoint() and Bind<Managed>Endpoint() to the class.

    ``Open...Endpoint(aActor)`` binds ``aActor`` as a managee of this actor,
    marks it as awaiting the remote bind, and returns a ``ManagedEndpoint``
    typed for the *other* side. ``Bind...Endpoint(aEndpoint, aActor)``
    forwards to ``ManagedEndpoint::Bind`` with this actor and the managee
    container, returning its bool result.
    """
    side = self.side
    mname = managed.name()
    localEp = "ManagedEndpoint<%s>" % _actorName(mname, side)
    remoteEp = "ManagedEndpoint<%s>" % _actorName(mname, _otherSide(side))
    actor = _HybridDecl(ipdl.type.ActorType(managed), "aActor")

    def actorParam():
        # A fresh Decl for each signature so the two methods never share
        # an AST node.
        return Decl(self.protocol.managedCxxType(managed, side), actor.name)

    # ManagedEndpoint<PThere> OpenPEndpoint(PHere* aActor)
    opener = MethodDefn(
        MethodDecl(
            "Open%sEndpoint" % mname,
            params=[actorParam()],
            ret=Type(remoteEp),
        )
    )
    # If binding fails, the generated method returns `ManagedEndpoint<...>()`.
    onBindFailure = ExprCall(ExprVar(remoteEp))
    opener.addcode(
        """
$*{bind}
// Mark our actor as awaiting the other side to be bound. This will
// be cleared when a `MANAGED_ENDPOINT_{DROPPED,BOUND}` message is
// received.
aActor->mAwaitingManagedEndpointBind = true;
return ${thereEp}(mozilla::ipc::PrivateIPDLInterface(), aActor);
""",
        bind=self.bindManagedActor(actor, errfn=onBindFailure),
        thereEp=remoteEp,
    )

    # bool BindPEndpoint(ManagedEndpoint<PHere> aEndpoint, PHere* aActor)
    binder = MethodDefn(
        MethodDecl(
            "Bind%sEndpoint" % mname,
            params=[Decl(Type(localEp), "aEndpoint"), actorParam()],
            ret=Type.BOOL,
        )
    )
    binder.addcode(
        """
return aEndpoint.Bind(mozilla::ipc::PrivateIPDLInterface(), aActor, this, ${container});
""",
        container=self.protocol.managedVar(managed, side),
    )

    self.cls.addstmts([opener, binder, Whitespace.NL])
def implementManagerIface(self):
    """Build the manager-side overrides common to every actor class.

    For toplevel protocols this also registers the SHMEM_CREATED and
    SHMEM_DESTROYED handlers on ``self.asyncSwitch``.

    Returns a list of class members: RemoveManagee(), DeallocManagee(), and
    a trailing blank line. Both methods switch on ``aProtocolId`` and fall
    back to ``FatalError("unreached")`` for unknown ids.
    """
    p = self.protocol
    protocolbase = Type("IProtocol", ptr=True)
    # Body emitted for impossible protocol ids and for protocols that
    # manage nothing.
    unreached = """
FatalError("unreached");
return;
"""

    if p.decl.type.isToplevel():
        # FIXME: This used to be declared conditionally based on whether
        # shmem appeared somewhere in the protocol hierarchy, however that
        # caused issues due to Shmem instances hidden within custom C++
        # types.
        self.asyncSwitch.addcase(
            CaseLabel("SHMEM_CREATED_MESSAGE_TYPE"),
            self.genShmemCreatedHandler(),
        )
        self.asyncSwitch.addcase(
            CaseLabel("SHMEM_DESTROYED_MESSAGE_TYPE"),
            self.genShmemDestroyedHandler(),
        )

    pvar = ExprVar("aProtocolId")
    listenervar = ExprVar("aListener")

    def managerMethod(name):
        # RemoveManagee and DeallocManagee share one signature:
        # (ProtocolId aProtocolId, IProtocol* aListener) override.
        return MethodDefn(
            MethodDecl(
                name,
                params=[
                    Decl(_protocolIdType(), pvar.name),
                    Decl(protocolbase, listenervar.name),
                ],
                methodspec=MethodSpec.OVERRIDE,
            )
        )

    # all protocols share the "same" RemoveManagee() implementation
    removemanagee = managerMethod(p.removeManageeMethod().name)
    if not p.managesStmts:
        removemanagee.addcode(unreached)
    else:
        switchontype = StmtSwitch(pvar)
        for managee in p.managesStmts:
            manageeipdltype = managee.decl.type
            manageecxxtype = _cxxBareType(
                ipdl.type.ActorType(manageeipdltype), self.side
            )
            case = ExprCode(
                """
${container}.EnsureRemoved(static_cast<${manageecxxtype}>(aListener));
return;
""",
                manageecxxtype=manageecxxtype,
                container=p.managedVar(manageeipdltype, self.side),
            )
            switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case)
        switchontype.addcase(DefaultLabel(), ExprCode(unreached))
        removemanagee.addstmt(switchontype)

    # The `DeallocManagee` method is called for managed actors to trigger
    # deallocation when ActorLifecycleProxy is freed.
    deallocmanagee = managerMethod(p.deallocManageeMethod().name)
    # Reference counted actor types don't have corresponding `Dealloc`
    # methods, as they are deallocated by releasing the IPDL-held
    # reference, so they get no case here.
    deallocTypes = [
        managee.decl.type
        for managee in p.managesStmts
        if not managee.decl.type.isRefcounted()
    ]
    if not deallocTypes:
        # Either nothing is managed or every managee is refcounted. Emit the
        # body directly instead of a switch whose only label is `default:`
        # (same workaround as bug 509581; MSVC warns with C4065).
        deallocmanagee.addcode(unreached)
    else:
        switchontype = StmtSwitch(pvar)
        for manageeipdltype in deallocTypes:
            case = StmtCode(
                """
${concrete}->${dealloc}(static_cast<${type}>(aListener));
return;
""",
                concrete=self.concreteThis(),
                dealloc=_deallocMethod(manageeipdltype, self.side),
                type=_cxxBareType(ipdl.type.ActorType(manageeipdltype), self.side),
            )
            switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case)
        switchontype.addcase(DefaultLabel(), StmtCode(unreached))
        deallocmanagee.addstmt(switchontype)

    return [removemanagee, deallocmanagee, Whitespace.NL]
def genShmemCreatedHandler(self):
    """Return the async-switch case body for SHMEM_CREATED_MESSAGE_TYPE.

    Only valid on toplevel protocols. The generated block returns
    MsgPayloadError when ``ShmemCreated(msg)`` fails, MsgProcessed otherwise.
    """
    assert self.protocol.decl.type.isToplevel()
    template = """
{
if (!ShmemCreated(${msgvar})) {
return MsgPayloadError;
}
return MsgProcessed;
}
"""
    return StmtCode(template, msgvar=self.msgvar)
def genShmemDestroyedHandler(self):
    """Return the async-switch case body for SHMEM_DESTROYED_MESSAGE_TYPE.

    Only valid on toplevel protocols. The generated block returns
    MsgPayloadError when ``ShmemDestroyed(msg)`` fails, MsgProcessed
    otherwise.
    """
    assert self.protocol.decl.type.isToplevel()
    template = """
{
if (!ShmemDestroyed(${msgvar})) {
return MsgPayloadError;
}
return MsgProcessed;
}
"""
    return StmtCode(template, msgvar=self.msgvar)
# -------------------------------------------------------------------------
# The next few functions are the crux of the IPDL code generator.
# They generate code for all the nasty work of message
# serialization/deserialization and dispatching handlers for
# received messages.
##
def concreteThis(self):
    """Return an expression for `this` typed as the concrete actor class.

    With a ``virtual`` impl attribute this is plain ``this``. Otherwise it is
    ``static_cast<Name*>(this)``, where Name is either the attribute's string
    value or, when no attribute is given, the protocol name without its
    leading "P" plus the capitalized side (e.g. PFoo -> FooParent).
    """
    implAttr = self.protocol.implAttribute(self.side)
    if implAttr == "virtual":
        return ExprVar.THIS
    if implAttr is not None:
        assert isinstance(implAttr, ipdl.ast.StringLiteral)
        concreteName = implAttr.value
    else:
        assert self.protocol.name.startswith("P")
        concreteName = self.protocol.name[1:] + self.side.capitalize()
    return ExprCode("static_cast<${className}*>(this)", className=concreteName)
def thisCall(self, function, args):
    """Return the call expression `<concrete this>->function(args...)`."""
    receiver = self.concreteThis()
    callee = ExprSelect(receiver, "->", function)
    return ExprCall(callee, args=args)
def visitMessageDecl(self, md):
    """Generate the send method(s) and dispatch case(s) for one message.

    When this side sends ``md``, the Send*() members (plus a helper ctor and a
    promise overload where applicable) are added to ``self.cls``, and the
    reply case returned by the generator is registered. When this side
    receives ``md``, the matching receive case is registered. Cases go into
    ``self.asyncSwitch`` or ``self.syncSwitch`` according to the message's
    sync-ness.
    """
    isctor = md.decl.type.isCtor()
    isdtor = md.decl.type.isDtor()
    decltype = md.decl.type
    sendmethod = None
    movesendmethod = None
    promisesendmethod = None
    recvlbl, recvcase = None, None

    def addRecvCase(lbl, case):
        # Route the case into the switch that matches the message's
        # sync-ness; any other kind is a generator bug.
        if decltype.isAsync():
            self.asyncSwitch.addcase(lbl, case)
        elif decltype.isSync():
            self.syncSwitch.addcase(lbl, case)
        else:
            assert 0

    if self.sendsMessage(md):
        isasync = decltype.isAsync()
        # NOTE: Don't generate helper ctors for refcounted types.
        #
        # Safety concerns around providing your own actor to a ctor (namely
        # that the return value won't be checked, and the argument will be
        # `delete`-ed) are less critical with refcounted actors, due to the
        # actor being held alive by the callsite.
        #
        # This allows refcounted actors to not implement crashing AllocPFoo
        # methods on the sending side.
        if isctor and not md.decl.type.constructedType().isRefcounted():
            self.cls.addstmts([self.genHelperCtor(md), Whitespace.NL])
        if isctor and isasync:
            sendmethod, (recvlbl, recvcase) = self.genAsyncCtor(md)
        elif isctor:
            sendmethod = self.genBlockingCtorMethod(md)
        elif isdtor:
            assert isasync
            sendmethod, (recvlbl, recvcase) = self.genDtor(md)
        elif isasync:
            (
                sendmethod,
                movesendmethod,
                promisesendmethod,
                (recvlbl, recvcase),
            ) = self.genAsyncSendMethod(md)
        else:
            sendmethod, movesendmethod = self.genBlockingSendMethod(md)
    # XXX figure out what to do here
    # Toplevel dtors get no Send method. The reply case produced by
    # genDtor above is still registered below.
    if isdtor and md.decl.type.constructedType().isToplevel():
        sendmethod = None
    if sendmethod is not None:
        self.cls.addstmts([sendmethod, Whitespace.NL])
    if movesendmethod is not None:
        self.cls.addstmts([movesendmethod, Whitespace.NL])
    if promisesendmethod is not None:
        self.cls.addstmts([promisesendmethod, Whitespace.NL])
    # Send-side reply case (async ctor/dtor/async-with-returns only).
    if recvcase is not None:
        addRecvCase(recvlbl, recvcase)
        recvlbl, recvcase = None, None
    if self.receivesMessage(md):
        if isctor:
            recvlbl, recvcase = self.genCtorRecvCase(md)
        elif isdtor:
            recvlbl, recvcase = self.genDtorRecvCase(md)
        else:
            recvlbl, recvcase = self.genRecvCase(md)
        # XXX figure out what to do here
        # Toplevel dtors are not dispatched through the message switch.
        if isdtor and md.decl.type.constructedType().isToplevel():
            return
        addRecvCase(recvlbl, recvcase)
def genAsyncCtor(self, md):
    """Generate the Send method for an async constructor message.

    The generated method binds the actor to this manager, builds and sends
    the constructor message, and on send failure disconnects the actor and
    returns null; otherwise it returns the actor.

    Returns ``(method, (label, case))`` where the case handles the
    constructor's reply id by reporting the message as processed.
    """
    actor = md.actorDecl()
    sendmeth = MethodDefn(self.makeSendMethodDecl(md))
    msgvar, buildstmts = self.makeMessage(md, errfnSendCtor)
    sendok, sendstmts = self.sendAsync(md, msgvar)
    actorname = actor.ipdltype.protocol.name() + self.side.capitalize()
    destroystmts = self.destroyActor(
        md, actor.var(), why=_DestroyReason.FailedConstructor
    )
    sendmeth.addcode(
        """
$*{bind}
// Build our constructor message.
$*{stmts}
// Notify the other side about the newly created actor. This can
// fail if our manager has already been destroyed.
//
// NOTE: If the send call fails due to toplevel channel teardown,
// the `IProtocol::ChannelSend` wrapper absorbs the error for us,
// so we don't tear down actors unexpectedly.
$*{sendstmts}
// Warn, destroy the actor, and return null if the message failed to
// send. Otherwise, return the successfully created actor reference.
if (!${sendok}) {
NS_WARNING("Error sending ${actorname} constructor");
$*{destroy}
return nullptr;
}
return ${actor};
""",
        bind=self.bindManagedActor(actor),
        stmts=buildstmts,
        sendstmts=sendstmts,
        sendok=sendok,
        destroy=destroystmts,
        actor=actor.var(),
        actorname=actorname,
    )

    # TODO not really sure what to do with async ctor "replies" yet.
    # destroy actor if there was an error? tricky ...
    replycase = StmtBlock()
    replycase.addstmt(StmtReturn(_Result.Processed))
    return sendmeth, (CaseLabel(md.pqReplyId()), replycase)
def genBlockingCtorMethod(self, md):
    """Generate the Send method for a synchronous constructor message.

    The generated method binds the actor to this manager, builds the
    constructor message, and sends it synchronously. If the send fails it
    warns (naming the actor class, like the async constructor path), destroys
    the actor, and returns null. Otherwise it deserializes the reply and
    returns the actor.
    """
    actor = md.actorDecl()
    method = MethodDefn(self.makeSendMethodDecl(md))
    msgvar, stmts = self.makeMessage(md, errfnSendCtor)
    replyvar = self.replyvar
    sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar)
    replystmts = self.deserializeReply(
        md,
        replyvar,
        self.side,
        errfnSendCtor,
        errfnSentinel(ExprLiteral.NULL),
    )
    # `actorname` was previously passed but never referenced by the
    # template; it now names the actor in the warning, matching genAsyncCtor.
    method.addcode(
        """
$*{bind}
// Build our constructor message.
$*{stmts}
// Synchronously send the constructor message to the other side. If
// the send fails, e.g. due to the remote side shutting down, the
// actor will be destroyed and potentially freed.
UniquePtr<Message> ${replyvar};
$*{sendstmts}
if (!(${sendok})) {
// Warn, destroy the actor and return null if the message
// failed to send.
NS_WARNING("Error sending ${actorname} constructor");
$*{destroy}
return nullptr;
}
$*{replystmts}
return ${actor};
""",
        bind=self.bindManagedActor(actor),
        stmts=stmts,
        replyvar=replyvar,
        sendstmts=sendstmts,
        sendok=sendok,
        destroy=self.destroyActor(
            md, actor.var(), why=_DestroyReason.FailedConstructor
        ),
        replystmts=replystmts,
        actor=actor.var(),
        actorname=actor.ipdltype.protocol.name() + self.side.capitalize(),
    )
    return method
def bindManagedActor(self, actordecl, errfn=ExprLiteral.NULL, idexpr=None):
    """Return statements that attach a new managee to this actor.

    The generated C++ returns ``errfn`` if the actor pointer is null or if
    ``SetManagerAndRegister(this[, idexpr])`` fails. On success it inserts
    the actor into this side's managee container. ``idexpr``, when given, is
    passed as the explicit ID to register under.
    """
    actorproto = actordecl.ipdltype.protocol
    extraArgs = [] if idexpr is None else [idexpr]
    template = """
if (!${actor}) {
NS_WARNING("Cannot bind null ${actorname} actor");
return ${errfn};
}
if (${actor}->SetManagerAndRegister($,{setManagerArgs})) {
${container}.Insert(${actor});
} else {
NS_WARNING("Failed to bind ${actorname} actor");
return ${errfn};
}
"""
    bindstmt = StmtCode(
        template,
        actor=actordecl.var(),
        actorname=actorproto.name() + self.side.capitalize(),
        errfn=errfn,
        setManagerArgs=[ExprVar.THIS] + extraArgs,
        container=self.protocol.managedVar(actorproto, self.side),
    )
    return [bindstmt]
def genHelperCtor(self, md):
    """Generate the constructor overload that allocates its own actor.

    The overload takes the usual Send arguments minus the leading actor
    parameter, calls the Alloc method for the actor, and forwards everything
    (moved) to the full Send method.
    """
    helperdecl = self.makeSendMethodDecl(md)
    # Drop the actor parameter: this overload creates the actor itself.
    helperdecl.params = helperdecl.params[1:]
    allocstmt = self.callAllocActor(md, retsems="out", side=self.side)
    forwardcall = ExprCall(
        ExprVar(helperdecl.name), args=md.makeCxxArgs(paramsems="move")
    )
    helper = MethodDefn(helperdecl)
    helper.addstmts([allocstmt, StmtReturn(forwardcall)])
    return helper
def genDtor(self, md):
    """Generate the static Send__delete__ method for a destructor message.

    The generated method checks the actor can still send, builds and sends
    the message from that actor, disconnects the actor, and returns the send
    result. Returns ``(method, (label, case))`` where the case acknowledges
    the dtor's reply id.
    """
    actorvar = ExprVar("actor")
    method = MethodDefn(self.makeDtorMethodDecl(md, actorvar))
    method.addstmt(self.dtorPrologue(actorvar))
    msgvar, msgstmts = self.makeMessage(md, errfnSendDtor, actorvar)
    sendok, sendstmts = self.sendAsync(md, msgvar, actorvar)
    method.addstmts(
        [
            *msgstmts,
            *sendstmts,
            Whitespace.NL,
            *self.dtorEpilogue(md, actorvar),
            StmtReturn(sendok),
        ]
    )

    # TODO if the dtor is "inherently racy", keep the actor alive
    # until the other side acks
    replycase = StmtBlock()
    replycase.addstmt(StmtReturn(_Result.Processed))
    return method, (CaseLabel(md.pqReplyId()), replycase)
def destroyActor(self, md, actorexpr, why=_DestroyReason.Deletion):
    """Return statements calling ``actor->ActorDisconnected(why)``.

    ``md`` is accepted for signature parity with callers and is unused.
    """
    disconnect = StmtCode(
        """
${actor}->ActorDisconnected(${why});
""",
        actor=actorexpr,
        why=why,
    )
    return [disconnect]
def dtorPrologue(self, actorexpr):
    """Return a guard making Send__delete__ return false for a null or closed actor."""
    template = """
if (!${actor} || !${actor}->CanSend()) {
NS_WARNING("Attempt to __delete__ missing or closed actor");
return false;
}
"""
    return StmtCode(template, actor=actorexpr)
def dtorEpilogue(self, md, actorexpr):
    """Return the statements that disconnect an actor after __delete__."""
    return self.destroyActor(md, actorexpr, why=_DestroyReason.Deletion)
def genRecvAsyncReplyCase(self, md):
    """Return ``(label, case)`` dispatching the reply to async message ``md``.

    The generated case pops the pending callback for this reply from the IPC
    channel (a fatal error if none is found). It then either deserializes the
    return value(s) and calls ``Resolve`` or reads the
    ``ResponseRejectReason`` and calls ``Reject``, depending on the
    ``resolve__`` flag read in the prologue. Requires ``md.returns`` to be
    non-empty (``md.returns[0]`` is read below).
    """
    lbl = CaseLabel(md.pqReplyId())
    case = StmtBlock()
    resolve, reason, prologue, desrej, desstmts = self.deserializeAsyncReply(
        md, self.side, errfnRecv, errfnSentinel(_Result.ValuError)
    )
    # Multiple return values are delivered to the callback as a tuple.
    if len(md.returns) > 1:
        resolvetype = _tuple([d.bareType(self.side) for d in md.returns])
        resolvearg = ExprCall(
            ExprVar("std::make_tuple"), args=[ExprMove(p.var()) for p in md.returns]
        )
    else:
        resolvetype = md.returns[0].bareType(self.side)
        resolvearg = ExprMove(md.returns[0].var())
    case.addcode(
        """
$*{prologue}
UniquePtr<MessageChannel::UntypedCallbackHolder> untypedCallback =
GetIPCChannel()->PopCallback(${msgvar}, Id());
typedef MessageChannel::CallbackHolder<${resolvetype}> CallbackHolder;
auto* callback = static_cast<CallbackHolder*>(untypedCallback.get());
if (!callback) {
FatalError("Error unknown callback");
return MsgProcessingError;
}
if (${resolve}) {
$*{desstmts}
callback->Resolve(${resolvearg});
} else {
$*{desrej}
callback->Reject(std::move(${reason}));
}
return MsgProcessed;
""",
        prologue=prologue,
        msgvar=self.msgvar,
        resolve=resolve,
        resolvetype=resolvetype,
        desstmts=desstmts,
        resolvearg=resolvearg,
        desrej=desrej,
        reason=reason,
    )
    return (lbl, case)
def genAsyncSendMethod(self, md):
    """Generate the Send method(s) for an ordinary async message.

    Returns ``(method, movemethod, promisemethod, (label, case))``.
    ``movemethod`` is always None. When ``md`` has return values, a
    promise-returning overload and the reply-dispatch case are produced as
    well; otherwise those three slots are None. The ``VirtualSendImpl``
    attribute makes the generated Send methods virtual.
    """
    isVirtual = "VirtualSendImpl" in md.attributes

    def sendDecl(**kwargs):
        decl = self.makeSendMethodDecl(md, **kwargs)
        if isVirtual:
            decl.methodspec = MethodSpec.VIRTUAL
        return decl

    method = MethodDefn(sendDecl())
    msgvar, msgstmts = self.makeMessage(md, errfnSend)
    retvar, sendstmts = self.sendAsync(md, msgvar)
    method.addstmts(msgstmts + [Whitespace.NL] + sendstmts + [StmtReturn(retvar)])

    if not md.returns:
        return method, None, None, (None, None)

    # Messages with returns also get a promise overload plus reply handling.
    promisemethod = MethodDefn(sendDecl(promise=True))
    promisemethod.addstmts(self.sendAsyncWithPromise(md))
    replylbl, replycase = self.genRecvAsyncReplyCase(md)
    return method, None, promisemethod, (replylbl, replycase)
def genBlockingSendMethod(self, md):
    """Generate the Send method for a synchronous message.

    The generated method serializes the arguments, sends synchronously into a
    local reply, returns false if sending failed, then deserializes the reply
    into the out-parameters and returns true. Returns ``(method, None)``.
    """
    replyvar = self.replyvar
    method = MethodDefn(self.makeSendMethodDecl(md))
    msgvar, serstmts = self.makeMessage(md, errfnSend)
    sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar)
    bailout = StmtIf(ExprNot(sendok))
    bailout.addifstmt(StmtReturn.FALSE)
    desstmts = self.deserializeReply(
        md, replyvar, self.side, errfnSend, errfnSentinel()
    )
    replydecl = StmtDecl(Decl(Type("UniquePtr<Message>"), replyvar.name))
    method.addstmts(
        [
            *serstmts,
            Whitespace.NL,
            replydecl,
            *sendstmts,
            bailout,
            *desstmts,
            Whitespace.NL,
            StmtReturn.TRUE,
        ]
    )
    return method, None
def genCtorRecvCase(self, md):
    """Return ``(label, case)`` handling an incoming constructor message.

    The generated case deserializes the arguments (the first one is read as a
    raw ``ActorHandle``) and declares locals for any return values. It then
    allocates the actor via its Alloc method, binds it under the ID from the
    handle, invokes the Recv handler, and builds the reply (see makeReply,
    which emits nothing for async or reply-less messages).
    """
    lbl = CaseLabel(md.pqMsgId())
    case = StmtBlock()
    actorhandle = self.handlevar
    stmts = self.deserializeMessage(
        md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
    )
    idvar, saveIdStmts = self.saveActorId(md)
    case.addstmts(
        stmts
        + [
            StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[])
            for r in md.returns
        ]
        # alloc the actor, register it under the foreign ID
        + [self.callAllocActor(md, retsems="in", side=self.side)]
        + self.bindManagedActor(
            md.actorDecl(), errfn=_Result.ValuError, idexpr=_actorHId(actorhandle)
        )
        + [Whitespace.NL]
        + saveIdStmts
        + self.invokeRecvHandler(md)
        + self.makeReply(md, errfnRecv, idvar)
        + [Whitespace.NL, StmtReturn(_Result.Processed)]
    )
    return lbl, case
def genDtorRecvCase(self, md):
    """Return ``(label, case)`` handling an incoming __delete__ message.

    The generated case deserializes the arguments, invokes the Recv handler,
    saves the routing ID (only if a reply is sent), builds the reply, and
    finally disconnects ``this`` actor.
    """
    lbl = CaseLabel(md.pqMsgId())
    case = StmtBlock()
    stmts = self.deserializeMessage(
        md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
    )
    idvar, saveIdStmts = self.saveActorId(md)
    case.addstmts(
        stmts
        + [
            StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[])
            for r in md.returns
        ]
        + self.invokeRecvHandler(md)
        + [Whitespace.NL]
        + saveIdStmts
        + self.makeReply(md, errfnRecv, routingId=idvar)
        + [Whitespace.NL]
        # Disconnect only after the reply (if any) has been built.
        + self.dtorEpilogue(md, ExprVar.THIS)
        + [Whitespace.NL, StmtReturn(_Result.Processed)]
    )
    return lbl, case
def genRecvCase(self, md):
    """Return ``(label, case)`` dispatching an ordinary (non-ctor/dtor) message.

    The generated case deserializes the arguments and declares storage for the
    return values. For async messages with returns it declares a
    ``resolver`` instead. It then calls the Recv handler and, via makeReply,
    builds the reply for sync messages.
    """
    stmts = self.deserializeMessage(
        md, self.side, errfn=errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
    )
    idvar, saveIdStmts = self.saveActorId(md)
    if md.decl.type.isAsync() and md.returns:
        # Async replies are delivered through a resolver, not out-params.
        retdecls = self.makeResolver(md, errfnRecv, routingId=idvar)
    else:
        retdecls = [
            StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[])
            for r in md.returns
        ]
    body = (
        stmts
        + saveIdStmts
        + retdecls
        + self.invokeRecvHandler(md)
        + [Whitespace.NL]
        + self.makeReply(md, errfnRecv, routingId=idvar)
        + [StmtReturn(_Result.Processed)]
    )
    case = StmtBlock()
    case.addstmts(body)
    return CaseLabel(md.pqMsgId()), case
# helper methods
def makeMessage(self, md, errfn, fromActor=None):
    """Build the statements that construct and serialize message ``md``.

    Returns ``(msgvar, stmts)``. ``stmts`` create the message through its
    constructor function (with ``self.protocol.routingId(fromActor)``), open
    an ``IPC::MessageWriter`` on it for ``fromActor`` (or ``this``), and
    write every parameter with sentinel checks. ``errfn`` is unused here.
    """
    msgvar = self.msgvar
    writervar = ExprVar("writer__")
    writerActor = fromActor or ExprVar.THIS
    msgdecl = StmtDecl(
        Decl(Type("UniquePtr<IPC::Message>"), msgvar.name),
        init=ExprCall(
            ExprVar(md.pqMsgCtorFunc()),
            args=[self.protocol.routingId(fromActor)],
        ),
    )
    writerdecl = StmtDecl(
        Decl(Type("IPC::MessageWriter"), writervar.name),
        initargs=[ExprDeref(msgvar), writerActor],
    )
    paramwrites = [
        _ParamTraits.checkedWrite(
            param.ipdltype,
            param.var(),
            ExprAddrOf(writervar),
            sentinelKey=param.name,
        )
        for param in md.params
    ]
    stmts = [
        msgdecl,
        writerdecl,
        Whitespace.NL,
        *paramwrites,
        Whitespace.NL,
        *self.setMessageFlags(md, msgvar),
    ]
    return msgvar, stmts
def makeResolver(self, md, errfn, routingId):
    """Return statements declaring the ``resolver`` passed to an async Recv handler.

    Only async messages with a reply get a resolver; otherwise ``[]``. The
    generated C++ pre-builds the reply message (copying the request's seqno),
    wraps it in an ``IPDLResolverInner``, and defines a ``resolver`` lambda.
    When called with the return value (a tuple when there are several), the
    lambda serializes the values into the reply and logs it. ``errfn`` is
    unused. A ``None`` ``routingId`` falls back to this protocol's routing id.
    """
    if routingId is None:
        routingId = self.protocol.routingId()
    if not md.decl.type.isAsync() or not md.hasReply():
        return []

    def paramValue(idx):
        # Multiple return values arrive packed in a tuple.
        assert idx < len(md.returns)
        if len(md.returns) > 1:
            return ExprCode("std::get<${idx}>(aParam)", idx=idx)
        return ExprVar("aParam")

    serializeParams = [
        _ParamTraits.checkedWrite(
            p.ipdltype,
            paramValue(idx),
            ExprAddrOf(ExprVar("writer__")),
            sentinelKey=p.name,
        )
        for idx, p in enumerate(md.returns)
    ]
    return [
        StmtCode(
            """
UniquePtr<IPC::Message> ${replyvar}(${replyCtor}(${routingId}));
${replyvar}->set_seqno(${msgvar}.seqno());
RefPtr<mozilla::ipc::IPDLResolverInner> resolver__ =
new mozilla::ipc::IPDLResolverInner(std::move(${replyvar}), this);
${resolvertype} resolver = [resolver__ = std::move(resolver__)](${resolveType} aParam) {
resolver__->Resolve([&] (IPC::Message* ${replyvar}, IProtocol* self__) {
IPC::MessageWriter writer__(*${replyvar}, self__);
$*{serializeParams}
${logSendingReply}
});
};
""",
            msgvar=self.msgvar,
            resolvertype=Type(md.resolverName()),
            routingId=routingId,
            resolveType=_resolveType(md.returns, self.side),
            replyvar=self.replyvar,
            replyCtor=ExprVar(md.pqReplyCtorFunc()),
            serializeParams=serializeParams,
            logSendingReply=self.logMessage(
                md,
                self.replyvar,
                "Sending reply ",
                actor=ExprVar("self__"),
            ),
        )
    ]
def makeReply(self, md, errfn, routingId):
    """Return statements that build and serialize a synchronous reply.

    Empty for messages without a reply and for async messages (those reply
    through a resolver). Otherwise the generated code assigns a new reply
    message to the reply variable, writes every return value, and logs it.
    A ``None`` ``routingId`` falls back to this protocol's routing id.
    ``errfn`` is unused.
    """
    if routingId is None:
        routingId = self.protocol.routingId()
    msgtype = md.decl.type
    # TODO special cases for async ctor/dtor replies
    if not msgtype.hasReply() or msgtype.isAsync():
        return []

    replyvar = self.replyvar
    prelude = [
        StmtExpr(
            ExprAssn(
                replyvar,
                ExprCall(ExprVar(md.pqReplyCtorFunc()), args=[routingId]),
            )
        ),
        StmtDecl(
            Decl(Type("IPC::MessageWriter"), "writer__"),
            initargs=[ExprDeref(replyvar), ExprVar.THIS],
        ),
        Whitespace.NL,
    ]
    writes = [
        _ParamTraits.checkedWrite(
            ret.ipdltype,
            ret.var(),
            ExprAddrOf(ExprVar("writer__")),
            sentinelKey=ret.name,
        )
        for ret in md.returns
    ]
    return (
        prelude
        + writes
        + self.setMessageFlags(md, replyvar)
        + [self.logMessage(md, replyvar, "Sending reply ")]
    )
def setMessageFlags(self, md, var, seqno=None):
    """Return flag-setting statements for message ``var``, then a blank line.

    Only the sequence number is handled: ``var->set_seqno(seqno)`` is emitted
    when ``seqno`` is truthy. ``md`` is unused.
    """
    if not seqno:
        return [Whitespace.NL]
    setseqno = StmtExpr(ExprCall(ExprSelect(var, "->", "set_seqno"), args=[seqno]))
    return [setseqno, Whitespace.NL]
def deserializeMessage(self, md, side, errfn, errfnSent):
    """Return statements that log a received message and read its parameters.

    For constructors, the first parameter is read as a raw ``ActorHandle``
    into ``self.handlevar``, because the real actor does not exist yet. For
    messages marked tainted, parameters are read as ``Tainted<T>`` unless
    they carry the ``NoTaint`` attribute. With no parameters, only the
    logging/profiler statements are returned.
    """
    msgvar = self.msgvar
    readervar = ExprVar("reader__")
    stmts = [
        self.logMessage(md, ExprAddrOf(msgvar), "Received ", receiving=True),
        self.profilerLabel(md),
        Whitespace.NL,
    ]
    if not md.params:
        return stmts

    def paramType(param):
        if md.decl.type.tainted and "NoTaint" not in param.attributes:
            return Type("Tainted", T=param.bareType(side))
        return param.bareType(side)

    reads = []
    params = md.params[0:]
    if md.decl.type.isCtor():
        # return the raw actor handle so that its ID can be used
        # to construct the "real" actor
        handletype = Type("ActorHandle")
        reads.append(
            _ParamTraits.checkedRead(
                None,
                handletype,
                self.handlevar,
                ExprAddrOf(readervar),
                errfn,
                "'%s'" % handletype.name,
                sentinelKey="actor",
                errfnSentinel=errfnSent,
            )
        )
        params = md.params[1:]

    reads += [
        _ParamTraits.checkedRead(
            param.ipdltype,
            paramType(param),
            param.var(),
            ExprAddrOf(readervar),
            errfn,
            "'%s'" % param.ipdltype.name(),
            sentinelKey=param.name,
            errfnSentinel=errfnSent,
        )
        for param in params
    ]

    readerdecl = StmtDecl(
        Decl(Type("IPC::MessageReader"), readervar.name),
        initargs=[msgvar, ExprVar.THIS],
    )
    stmts += [
        readerdecl,
        Whitespace.NL,
        *reads,
        StmtCode("${reader}.EndRead();\n", reader=readervar),
    ]
    return stmts
def deserializeAsyncReply(self, md, side, errfn, errfnSent):
    """Build the pieces needed to decode the reply to an async message.

    Returns ``(resolve, reason, prologue, desrej, stmts)``:

    * ``resolve`` / ``reason`` -- the ``resolve__`` / ``reason__`` C++ vars.
    * ``prologue`` -- logging, profiler label, the ``IPC::MessageReader``
      declaration and the read of the ``resolve__`` flag.
    * ``desrej`` -- reads ``reason__`` and ends the read (reject path).
    * ``stmts`` -- reads the return values and ends the read (resolve path).

    The same 5-tuple shape is returned even when ``md.returns`` is empty.
    The reader is always declared, since the reject path reads through it;
    previously only ``prologue`` was returned there, which a caller
    unpacking five values could not use.
    """
    msgvar = self.msgvar
    readervar = ExprVar("reader__")
    resolve = ExprVar("resolve__")
    reason = ExprVar("reason__")

    # NOTE: The `resolve__` and `reason__` parameters don't have sentinels,
    # as they are serialized by the IPDLResolverInner type in
    # ProtocolUtils.cpp rather than by generated code.
    desresolve = [
        StmtCode(
            """
bool resolve__ = false;
if (!IPC::ReadParam(&${readervar}, &resolve__)) {
FatalError("Error deserializing bool");
return MsgValueError;
}
""",
            readervar=readervar,
        ),
    ]
    desrej = [
        StmtCode(
            """
ResponseRejectReason reason__{};
if (!IPC::ReadParam(&${readervar}, &reason__)) {
FatalError("Error deserializing ResponseRejectReason");
return MsgValueError;
}
${readervar}.EndRead();
""",
            readervar=readervar,
        ),
    ]

    prologue = [
        self.logMessage(md, ExprAddrOf(msgvar), "Received ", receiving=True),
        self.profilerLabel(md),
        Whitespace.NL,
        StmtDecl(
            Decl(Type("IPC::MessageReader"), readervar.name),
            initargs=[msgvar, ExprVar.THIS],
        ),
    ] + desresolve

    reads = []
    returns = md.returns
    if md.decl.type.isCtor():
        # return the raw actor handle so that its ID can be used
        # to construct the "real" actor
        handletype = Type("ActorHandle")
        reads.append(
            _ParamTraits.checkedRead(
                None,
                handletype,
                self.handlevar,
                ExprAddrOf(readervar),
                errfn,
                "'%s'" % handletype.name,
                sentinelKey="actor",
                errfnSentinel=errfnSent,
            )
        )
        returns = returns[1:]

    reads += [
        _ParamTraits.checkedRead(
            p.ipdltype,
            p.bareType(side),
            p.var(),
            ExprAddrOf(readervar),
            errfn,
            "'%s'" % p.ipdltype.name(),
            sentinelKey=p.name,
            errfnSentinel=errfnSent,
        )
        for p in returns
    ]
    stmts = reads + [StmtCode("${reader}.EndRead();", reader=readervar)]
    return resolve, reason, prologue, desrej, stmts
def deserializeReply(self, md, replyexpr, side, errfn, errfnSentinel, actor=None):
    """Return statements that log a sync reply and unpack its return values.

    Each return value is read into a ``<name>__reply`` temporary, then
    move-assigned through the matching out-parameter pointer. With no return
    values, only the logging statements are produced.
    """
    logged = [
        Whitespace.NL,
        self.logMessage(md, replyexpr, "Received reply ", actor, receiving=True),
    ]
    if not md.returns:
        return logged

    readervar = ExprVar("reader__")

    def replyTemp(ret):
        return ExprVar(ret.var().name + "__reply")

    # NOTE(review): the reader wraps self.replyvar rather than replyexpr;
    # the visible callers pass self.replyvar for both, so they agree today.
    readerdecl = StmtDecl(
        Decl(Type("IPC::MessageReader"), readervar.name),
        initargs=[ExprDeref(self.replyvar), ExprVar.THIS],
    )
    reads = [
        _ParamTraits.checkedRead(
            ret.ipdltype,
            ret.bareType(side),
            replyTemp(ret),
            ExprAddrOf(readervar),
            errfn,
            "'%s'" % ret.ipdltype.name(),
            sentinelKey=ret.name,
            errfnSentinel=errfnSentinel,
        )
        for ret in md.returns
    ]
    # Move the decoded temporaries into the caller's out-parameters.
    assigns = [
        StmtExpr(ExprAssn(ExprDeref(ret.var()), ExprMove(replyTemp(ret))))
        for ret in md.returns
    ]
    return (
        logged
        + [Whitespace.NL, readerdecl, Whitespace.NL]
        + reads
        + assigns
        + [StmtCode("${reader}.EndRead();", reader=readervar)]
    )
def sendAsync(self, md, msgexpr, actor=None):
    """Return ``(sendok, stmts)`` that asynchronously send ``msgexpr``.

    ``ChannelSend`` is called on ``actor`` when given, else on ``this``. For
    messages with return values, the reply id and the moved
    ``aResolve``/``aReject`` callbacks are passed along and ``sendok`` is
    None. Otherwise the result is stored in ``bool sendok__``, and that
    variable is returned as ``sendok``.
    """
    if actor is None:
        send = ExprVar("ChannelSend")
    else:
        send = ExprSelect(actor, "->", "ChannelSend")
    stmts = [
        Whitespace.NL,
        self.logMessage(md, msgexpr, "Sending ", actor),
        self.profilerLabel(md),
        Whitespace.NL,
    ]

    if not md.returns:
        sendok = ExprVar("sendok__")
        stmts.append(
            StmtDecl(
                Decl(Type.BOOL, sendok.name),
                init=ExprCall(send, args=[ExprMove(msgexpr)]),
            )
        )
        return (sendok, stmts)

    callbackargs = [
        ExprMove(msgexpr),
        ExprVar(md.pqReplyId()),
        ExprMove(ExprVar("aResolve")),
        ExprMove(ExprVar("aReject")),
    ]
    stmts.append(StmtExpr(ExprCall(send, args=callbackargs)))
    return (None, stmts)
def sendBlocking(self, md, msgexpr, replyexpr, actor=None):
    """Return ``(sendok, stmts)`` that synchronously send ``msgexpr``.

    ``stmts`` log the send, add a profiler label, and declare
    ``bool sendok__ = false``. They then open a nested block that emits
    ``AUTO_PROFILER_TRACING_MARKER("Sync IPC", "<Protocol>::<Msg>", IPC)`` and
    assigns ``ChannelSend(std::move(msg), &reply)`` (called on ``actor`` when
    given) to ``sendok__``. Registers ``mozilla/ProfilerMarkers.h`` as an
    external include.
    """
    send = ExprVar("ChannelSend")
    if actor is not None:
        send = ExprSelect(actor, "->", send.name)
    sendok = ExprVar("sendok__")
    self.externalIncludes.add("mozilla/ProfilerMarkers.h")
    return (
        sendok,
        (
            [
                Whitespace.NL,
                self.logMessage(md, msgexpr, "Sending ", actor),
                self.profilerLabel(md),
            ]
            + [
                Whitespace.NL,
                StmtDecl(Decl(Type.BOOL, sendok.name), init=ExprLiteral.FALSE),
                # Nested block, presumably so the tracing marker only spans
                # the send call itself.
                StmtBlock(
                    [
                        StmtExpr(
                            ExprCall(
                                ExprVar("AUTO_PROFILER_TRACING_MARKER"),
                                [
                                    ExprLiteral.String("Sync IPC"),
                                    ExprLiteral.String(
                                        self.protocol.name
                                        + "::"
                                        + md.prettyMsgName()
                                    ),
                                    ExprVar("IPC"),
                                ],
                            )
                        ),
                        StmtExpr(
                            ExprAssn(
                                sendok,
                                ExprCall(
                                    send,
                                    args=[ExprMove(msgexpr), ExprAddrOf(replyexpr)],
                                ),
                            )
                        ),
                    ]
                ),
            ]
        ),
    )
def sendAsyncWithPromise(self, md):
    """Return the body of the promise-returning Send overload.

    The generated code creates a promise (with direct task dispatch), calls
    the callback-based Send overload with moved arguments plus lambdas that
    resolve or reject that promise, and returns the promise. Multiple return
    values resolve the promise with a tuple.
    """
    side = self.side
    # Create a new promise, and forward to the callback send overload.
    promise = _makePromise(md.returns, side, resolver=True)
    if len(md.returns) > 1:
        resolvetype = _tuple([ret.bareType(side) for ret in md.returns])
    else:
        resolvetype = md.returns[0].bareType(side)

    onresolve = ExprCode(
        """
[promise__](${resolvetype}&& aValue) {
promise__->Resolve(std::move(aValue), __func__);
}
""",
        resolvetype=resolvetype,
    )
    onreject = ExprCode(
        """
[promise__](ResponseRejectReason&& aReason) {
promise__->Reject(std::move(aReason), __func__);
}
""",
        resolvetype=resolvetype,
    )
    sendargs = [ExprMove(param.var()) for param in md.params]
    sendargs += [onresolve, onreject]
    body = StmtCode(
        """
RefPtr<${promise}> promise__ = new ${promise}(__func__);
promise__->UseDirectTaskDispatch(__func__);
${send}($,{args});
return promise__;
""",
        promise=promise,
        send=md.sendMethod(),
        args=sendargs,
    )
    return [body]
def callAllocActor(self, md, retsems, side):
    """Return a declaration that allocates the actor constructed by ``md``.

    Emits ``<ActorType> actor = concreteThis->Alloc<P>(args...)``. For
    refcounted actors the local is a ``RefPtr`` instead of a raw pointer.
    Both the local's type and the Alloc method are resolved for ``side``.
    Previously the local always used ``self.side`` while the Alloc call used
    ``side``; the two now agree, which is identical for callers passing
    ``self.side``.
    """
    constructed = md.decl.type.constructedType()
    actordecl = md.actorDecl()
    actortype = actordecl.bareType(side)
    if constructed.isRefcounted():
        actortype.ptr = False
        actortype = _refptr(actortype)
    callalloc = self.thisCall(
        _allocMethod(constructed, side),
        args=md.makeCxxArgs(retsems=retsems, retcallsems="out", implicit=False),
    )
    return StmtDecl(Decl(actortype, actordecl.var().name), init=callalloc)
def invokeRecvHandler(self, md):
    """Return statements calling the Recv handler and checking its result.

    The generated code stores the handler's ``IPCResult`` in ``__ok``. If it
    is falsy, the code hits the protocol-error breakpoint and returns
    MsgProcessingError. Async messages with returns pass a resolver rather
    than out-parameters.
    """
    isAsyncWithReturns = md.decl.type.isAsync() and md.returns
    retsems = "resolver" if isAsyncWithReturns else "in"
    recvcall = self.thisCall(
        md.recvMethod(),
        md.makeCxxArgs(
            paramsems="move",
            retsems=retsems,
            retcallsems="out",
        ),
    )
    okdecl = StmtDecl(Decl(Type("mozilla::ipc::IPCResult"), "__ok"), init=recvcall)
    onfailure = StmtIf(ExprNot(ExprVar("__ok")))
    onfailure.addifstmts(
        [
            _protocolErrorBreakpoint("Handler returned error code!"),
            Whitespace(
                "// Error handled in mozilla::ipc::IPCResult\n", indent=True
            ),
            StmtReturn(_Result.ProcessingError),
        ]
    )
    return [okdecl, onfailure]
def makeDtorMethodDecl(self, md, actorvar):
    """Return the static Send__delete__ declaration taking the actor first."""
    decl = self.makeSendMethodDecl(md)
    actorInType = _cxxInType(
        ipdl.type.ActorType(md.decl.type.constructedType()),
        side=self.side,
        direction="send",
    )
    # Prepend the actor being deleted (in-place, like list.insert(0, ...)).
    decl.params[:0] = [Decl(actorInType, actorvar.name)]
    decl.methodspec = MethodSpec.STATIC
    return decl
def makeSendMethodDecl(self, md, promise=False, paramsems="in"):
    """Return the MethodDecl for the send method of |md| on self.side.

    The return semantics and C++ return type depend on the message:
      - async with returns, |promise| true:  "promise",  RefPtr<promiseName>
      - async with returns, |promise| false: "callback", void
      - anything else:                       "out",      bool
        (|promise| must be false here; this is asserted)
    |paramsems| is forwarded to md.makeCxxParams(). For constructor
    messages, the return type is replaced afterwards by the actor's bare
    type.
    """
    msgtype = md.decl.type
    implicit = msgtype.hasImplicitActorParam()
    if not (msgtype.isAsync() and md.returns):
        assert not promise
        returnsems, rettype = "out", Type.BOOL
    elif promise:
        returnsems, rettype = "promise", _refptr(Type(md.promiseName()))
    else:
        returnsems, rettype = "callback", Type.VOID

    sendname = md.sendMethod()
    sendparams = md.makeCxxParams(
        paramsems,
        returnsems=returnsems,
        side=self.side,
        implicit=implicit,
        direction="send",
    )
    # Request an unused-result warning for parent-side sends that don't use
    # a callback, and for synchronous constructors.
    warn = (self.side == "parent" and returnsems != "callback") or (
        msgtype.isCtor() and not msgtype.isAsync()
    )
    decl = MethodDecl(sendname, params=sendparams, warn_unused=warn, ret=rettype)
    if msgtype.isCtor():
        decl.ret = md.actorDecl().bareType(self.side)
    return decl
def logMessage(self, md, msgptr, pfx, actor=None, receiving=False):
    """Return a StmtCode that calls LogMessageForProtocol for |msgptr|.

    The call is guarded by LoggingEnabledFor(protocol name, side).
    |actor| falls back to `this` when not supplied (falsy), |pfx| is
    emitted as a string literal, and |receiving| selects
    MessageDirection::eReceiving rather than eSending. |md| is accepted but
    not used here.
    """
    if receiving:
        direction = "eReceiving"
    else:
        direction = "eSending"
    protocolname = self.protocol.name
    return StmtCode(
        """
if (mozilla::ipc::LoggingEnabledFor(${protocolname}, ${side})) {
mozilla::ipc::LogMessageForProtocol(
${actorname},
${actor}->ToplevelProtocol()->OtherPidMaybeInvalid(),
${pfx},
${msgptr}->type(),
mozilla::ipc::MessageDirection::${direction});
}
""",
        protocolname=ExprLiteral.String(protocolname),
        side=_cxxSide(self.side),
        actorname=ExprLiteral.String(_actorName(protocolname, self.side)),
        actor=actor if actor else ExprVar.THIS,
        pfx=ExprLiteral.String(pfx),
        msgptr=msgptr,
        direction=direction,
    )
def profilerLabel(self, md):
    """Return a StmtCode emitting AUTO_PROFILER_LABEL with the label
    "<protocol name>::<pretty message name>" and category OTHER.

    It also adds "mozilla/ProfilerLabels.h" to self.externalIncludes.
    """
    labelheader = "mozilla/ProfilerLabels.h"
    self.externalIncludes.add(labelheader)
    protocolname = self.protocol.name
    prettyname = md.prettyMsgName()
    return StmtCode(
        """
AUTO_PROFILER_LABEL("${name}::${msgname}", OTHER);
""",
        name=protocolname,
        msgname=prettyname,
    )
def saveActorId(self, md):
    """Return (idvar, stmts) for the `id__` routing-id local.

    idvar is always ExprVar("id__"). stmts declares it, initialized from
    self.protocol.routingId(), only when |md| has a reply. Per the original
    intent, the id is saved only when it will be used, which avoids
    unused-variable warnings. Otherwise stmts is empty.
    """
    idvar = ExprVar("id__")
    if not md.decl.type.hasReply():
        return idvar, []
    iddecl = Decl(_actorIdType(), idvar.name)
    return idvar, [StmtDecl(iddecl, self.protocol.routingId())]
class _GenerateProtocolParentCode(_GenerateProtocolActorCode):
    """Actor-code generator for the "parent" side.

    The parent sends every message whose direction is not `in`, and
    receives `inout` and `in` messages.
    """

    def __init__(self):
        super().__init__("parent")

    def sendsMessage(self, md):
        msgtype = md.decl.type
        return not msgtype.isIn()

    def receivesMessage(self, md):
        msgtype = md.decl.type
        return msgtype.isInout() or msgtype.isIn()
class _GenerateProtocolChildCode(_GenerateProtocolActorCode):
    """Actor-code generator for the "child" side.

    The child sends every message whose direction is not `out`, and
    receives `inout` and `out` messages.
    """

    def __init__(self):
        super().__init__("child")

    def sendsMessage(self, md):
        msgtype = md.decl.type
        return not msgtype.isOut()

    def receivesMessage(self, md):
        msgtype = md.decl.type
        return msgtype.isInout() or msgtype.isOut()
# -----------------------------------------------------------------------------
# Utility passes
##
def _splitClassDeclDefn(cls):
    """Split the non-inline methods of |cls| into declarations and
    definitions, modifying |cls| in place.

    Each MethodDefn in cls.stmts whose decl is not force_inline is replaced
    by a StmtDecl of its declaration. Each definition produced is appended,
    followed by Whitespace.NL, to a new Block. Returns (cls, thatBlock).
    """
    defns = Block()
    for idx, stmt in enumerate(cls.stmts):
        if not isinstance(stmt, MethodDefn) or stmt.decl.force_inline:
            continue
        decl, defn = _splitMethodDeclDefn(stmt, cls)
        cls.stmts[idx] = StmtDecl(decl)
        if defn:
            # Pure methods yield a declaration with no definition.
            defns.addstmts([defn, Whitespace.NL])
    return cls, defns
def _splitMethodDeclDefn(md, cls):
    """Split method definition |md| into (declaration, definition).

    A pure method returns (md.decl, None), because it has no body to emit.
    Otherwise, a deep copy of md.decl is kept as the in-class declaration.
    md.decl itself is then rewritten for out-of-line use: it is bound to
    |cls|, its method specifier is cleared, warn_unused is disabled,
    only_for_definition is set, and default values are removed from its
    Param entries. Returns (saved copy, md).
    """
    decl = md.decl
    if decl.methodspec == MethodSpec.PURE:
        return decl, None
    # Snapshot before mutating so the declaration keeps specifiers/defaults.
    classdecl = deepcopy(decl)
    decl.cls = cls
    # Method specifiers are not emitted on out-of-line definitions.
    decl.methodspec = MethodSpec.NONE
    decl.warn_unused = False
    decl.only_for_definition = True
    for param in (p for p in decl.params if isinstance(p, Param)):
        param.default = None
    return classdecl, md
def _splitFuncDeclDefn(fun):
    """Return (StmtDecl of |fun|'s declaration, |fun|) for a free function.

    |fun| must not be marked force_inline.
    """
    assert not fun.decl.force_inline
    declstmt = StmtDecl(fun.decl)
    return declstmt, fun